Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 10:32:17 +00:00)
deprecate params.bilingual=0, replace ReazonSpeechAsrDataModule with MultiDatasetAsrDataModule, not tested yet
This commit is contained in:
  parent: b2df5bbb83
  commit: eb5004880f
@@ -39,7 +39,7 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
-class ReazonSpeechAsrDataModule:
+class MultiDatasetAsrDataModule:
     """
     DataModule for k2 ASR experiments.
     It assumes there is always one train and valid dataloader,
@@ -333,23 +333,23 @@ class ReazonSpeechAsrDataModule:
         )
         return test_dl
 
-    @lru_cache()
-    def train_cuts(self) -> CutSet:
-        logging.info("About to get train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
-        )
+    # @lru_cache()
+    # def train_cuts(self) -> CutSet:
+    #     logging.info("About to get train cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
+    #     )
 
-    @lru_cache()
-    def valid_cuts(self) -> CutSet:
-        logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
-        )
+    # @lru_cache()
+    # def valid_cuts(self) -> CutSet:
+    #     logging.info("About to get dev cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
+    #     )
 
-    @lru_cache()
-    def test_cuts(self) -> List[CutSet]:
-        logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
-        )
+    # @lru_cache()
+    # def test_cuts(self) -> List[CutSet]:
+    #     logging.info("About to get test cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
+    #     )
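The accessors commented out above follow the usual icefall pattern: open the manifest lazily so cuts are streamed from disk, and cache the resulting CutSet so repeated calls reuse it. A minimal sketch of that pattern (the class name and path below are illustrative, not part of the commit):

    from functools import lru_cache
    from pathlib import Path

    from lhotse import CutSet, load_manifest_lazy


    class CutsProvider:
        """Illustrative stand-in for the commented-out accessors."""

        def __init__(self, manifest_dir: str):
            self.manifest_dir = Path(manifest_dir)

        @lru_cache()
        def train_cuts(self) -> CutSet:
            # load_manifest_lazy streams cuts from the gzipped JSONL manifest;
            # lru_cache keeps the CutSet so later calls do not reopen the file.
            return load_manifest_lazy(
                self.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
            )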
@@ -68,7 +68,7 @@ import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -573,7 +573,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
@@ -748,7 +748,7 @@ def main():
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    data_module = ReazonSpeechAsrDataModule(args)
+    data_module = MultiDatasetAsrDataModule(args)
     multi_dataset = MultiDataset(args)
 
     def remove_short_utt(c: Cut):
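MultiDatasetAsrDataModule.add_arguments(parser) follows the datamodule convention seen across icefall recipes: a classmethod that registers the data-related CLI options on the script's parser before parse_args() runs. A rough sketch of that shape (the specific options are assumptions for illustration, not taken from this commit):

    import argparse
    from pathlib import Path

    from icefall.utils import str2bool


    class ExampleAsrDataModule:
        @classmethod
        def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
            group = parser.add_argument_group(title="ASR data related options")
            # Hypothetical options, shown only to illustrate the pattern.
            group.add_argument("--manifest-dir", type=Path, default=Path("data/fbank"))
            group.add_argument("--max-duration", type=int, default=200)
            group.add_argument("--return-cuts", type=str2bool, default=True)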
@@ -57,7 +57,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -1085,8 +1085,8 @@ def run(rank, world_size, args):
 
         return True
 
-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
-    train_cuts = reazonspeech_corpus.train_cuts()
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)
+    train_cuts = multidataset_datamodule.train_cuts()
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
@@ -1097,12 +1097,12 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    train_dl = reazonspeech_corpus.train_dataloaders(
+    train_dl = multidataset_datamodule.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = reazonspeech_corpus.valid_cuts()
-    valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
+    valid_cuts = multidataset_datamodule.valid_cuts()
+    valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)
 
     if params.start_batch <= 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1242,7 +1242,7 @@ def scan_pessimistic_batches_for_oom(
 def main():
     raise RuntimeError("Please don't use this file directly!")
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     Tokenizer.add_arguments(parser)
     args = parser.parse_args()
 
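Pieced together from the hunks above, the intended training-side flow is: build the datamodule from the parsed args, obtain and filter the cuts, then ask the datamodule for dataloaders. A condensed sketch under the assumption that args, params, sampler_state_dict, and remove_short_and_long_utt come from the surrounding train script:

    from asr_datamodule import MultiDatasetAsrDataModule
    from multi_dataset import MultiDataset  # module path assumed

    datamodule = MultiDatasetAsrDataModule(args)
    multi_dataset = MultiDataset(args)

    # Cuts now come from MultiDataset, which muxes ReazonSpeech and MLS English.
    train_cuts = multi_dataset.train_cuts()
    train_cuts = train_cuts.filter(remove_short_and_long_utt)

    train_dl = datamodule.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )
    valid_dl = datamodule.valid_dataloaders(multi_dataset.dev_cuts())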
@@ -13,36 +13,36 @@ class MultiDataset:
         Args:
           manifest_dir:
             It is expected to contain the following files:
-            - reazonspeech_cuts_train.jsonl.gz
-            - librispeech_cuts_train-clean-100.jsonl.gz
-            - librispeech_cuts_train-clean-360.jsonl.gz
-            - librispeech_cuts_train-other-500.jsonl.gz
+            - mls_english/
+              - mls_eng_cuts_train.jsonl.gz
+              - mls_eng_cuts_dev.jsonl.gz
+              - mls_eng_cuts_test.jsonl.gz
+            - reazonspeech/
+              - reazonspeech_cuts_train.jsonl.gz
+              - reazonspeech_cuts_dev.jsonl.gz
+              - reazonspeech_cuts_test.jsonl.gz
         """
-        self.fbank_dir = Path(args.manifest_dir)
+        self.manifest_dir = Path(args.manifest_dir)
 
     def train_cuts(self) -> CutSet:
         logging.info("About to get multidataset train cuts")
 
-        logging.info("Loading Reazonspeech in lazy mode")
-        reazonspeech_cuts = load_manifest_lazy(
-            self.fbank_dir / "reazonspeech_cuts_train.jsonl.gz"
+        logging.info("Loading Reazonspeech TRAIN set in lazy mode")
+        reazonspeech_train_cuts = load_manifest_lazy(
+            self.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
         )
 
-        logging.info("Loading LibriSpeech in lazy mode")
-        train_clean_100_cuts = self.train_clean_100_cuts()
-        train_clean_360_cuts = self.train_clean_360_cuts()
-        train_other_500_cuts = self.train_other_500_cuts()
+        logging.info("Loading MLS English TRAIN set in lazy mode")
+        mls_eng_train_cuts = load_manifest_lazy(
+            self.manifest_dir / "mls_eng_cuts_train.jsonl.gz"
+        )
 
         return CutSet.mux(
-            reazonspeech_cuts,
-            train_clean_100_cuts,
-            train_clean_360_cuts,
-            train_other_500_cuts,
+            reazonspeech_train_cuts,
+            mls_eng_train_cuts,
             weights=[
-                len(reazonspeech_cuts),
-                len(train_clean_100_cuts),
-                len(train_clean_360_cuts),
-                len(train_other_500_cuts),
+                len(reazonspeech_train_cuts),
+                len(mls_eng_train_cuts),
             ],
         )
 
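CutSet.mux from lhotse merges the corpora lazily at iteration time: each drawn cut is sampled from one of the input streams with probability proportional to its weight, so weighting by the number of cuts keeps the mixture roughly proportional to corpus size without concatenating the manifests in memory. A sketch of the call, with placeholder manifest paths:

    from lhotse import CutSet, load_manifest_lazy

    reazonspeech = load_manifest_lazy("reazonspeech_cuts_train.jsonl.gz")  # path assumed
    mls_english = load_manifest_lazy("mls_eng_cuts_train.jsonl.gz")        # path assumed

    # Lazily multiplex the two streams; every cut is still yielded eventually,
    # but the interleaving follows the given weights.
    train_cuts = CutSet.mux(
        reazonspeech,
        mls_english,
        weights=[len(reazonspeech), len(mls_english)],
    )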
@@ -51,93 +51,90 @@ class MultiDataset:
 
         logging.info("Loading Reazonspeech DEV set in lazy mode")
         reazonspeech_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
+            self.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
         )
 
-        logging.info("Loading LibriSpeech DEV set in lazy mode")
-        dev_clean_cuts = self.dev_clean_cuts()
-        dev_other_cuts = self.dev_other_cuts()
+        logging.info("Loading MLS English DEV set in lazy mode")
+        mls_eng_dev_cuts = load_manifest_lazy(
+            self.manifest_dir / "mls_eng_cuts_dev.jsonl.gz"
+        )
 
         return CutSet.mux(
             reazonspeech_dev_cuts,
-            dev_clean_cuts,
-            dev_other_cuts,
+            mls_eng_dev_cuts,
             weights=[
                 len(reazonspeech_dev_cuts),
-                len(dev_clean_cuts),
-                len(dev_other_cuts),
+                len(mls_eng_dev_cuts),
             ],
         )
 
-    def test_cuts(self) -> Dict[str, CutSet]:
+    def test_cuts(self) -> CutSet:
         logging.info("About to get multidataset test cuts")
 
-        logging.info("Loading Reazonspeech set in lazy mode")
+        logging.info("Loading Reazonspeech TEST set in lazy mode")
         reazonspeech_test_cuts = load_manifest_lazy(
-            self.fbank_dir / "reazonspeech_cuts_test.jsonl.gz"
+            self.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
         )
-        reazonspeech_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
-        )
 
-        logging.info("Loading LibriSpeech set in lazy mode")
-        test_clean_cuts = self.test_clean_cuts()
-        test_other_cuts = self.test_other_cuts()
+        logging.info("Loading MLS English TEST set in lazy mode")
+        mls_eng_test_cuts = load_manifest_lazy(
+            self.manifest_dir / "mls_eng_cuts_test.jsonl.gz"
+        )
 
-        test_cuts = {
-            "reazonspeech_test": reazonspeech_test_cuts,
-            "reazonspeech_dev": reazonspeech_dev_cuts,
-            "librispeech_test_clean": test_clean_cuts,
-            "librispeech_test_other": test_other_cuts,
-        }
-
-        return test_cuts
+        return CutSet.mux(
+            reazonspeech_test_cuts,
+            mls_eng_test_cuts,
+            weights=[
+                len(reazonspeech_test_cuts),
+                len(mls_eng_test_cuts),
+            ],
+        )
 
-    @lru_cache()
-    def train_clean_100_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-100 cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
-        )
+    # @lru_cache()
+    # def train_clean_100_cuts(self) -> CutSet:
+    #     logging.info("About to get train-clean-100 cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+    #     )
 
-    @lru_cache()
-    def train_clean_360_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-360 cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
-        )
+    # @lru_cache()
+    # def train_clean_360_cuts(self) -> CutSet:
+    #     logging.info("About to get train-clean-360 cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+    #     )
 
-    @lru_cache()
-    def train_other_500_cuts(self) -> CutSet:
-        logging.info("About to get train-other-500 cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz"
-        )
+    # @lru_cache()
+    # def train_other_500_cuts(self) -> CutSet:
+    #     logging.info("About to get train-other-500 cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+    #     )
 
-    @lru_cache()
-    def dev_clean_cuts(self) -> CutSet:
-        logging.info("About to get dev-clean cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz"
-        )
+    # @lru_cache()
+    # def dev_clean_cuts(self) -> CutSet:
+    #     logging.info("About to get dev-clean cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+    #     )
 
-    @lru_cache()
-    def dev_other_cuts(self) -> CutSet:
-        logging.info("About to get dev-other cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz"
-        )
+    # @lru_cache()
+    # def dev_other_cuts(self) -> CutSet:
+    #     logging.info("About to get dev-other cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+    #     )
 
-    @lru_cache()
-    def test_clean_cuts(self) -> CutSet:
-        logging.info("About to get test-clean cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz"
-        )
+    # @lru_cache()
+    # def test_clean_cuts(self) -> CutSet:
+    #     logging.info("About to get test-clean cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+    #     )
 
-    @lru_cache()
-    def test_other_cuts(self) -> CutSet:
-        logging.info("About to get test-other cuts")
-        return load_manifest_lazy(
-            self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz"
-        )
+    # @lru_cache()
+    # def test_other_cuts(self) -> CutSet:
+    #     logging.info("About to get test-other cuts")
+    #     return load_manifest_lazy(
+    #         self.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+    #     )
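Since test_cuts() now returns one muxed CutSet rather than a Dict[str, CutSet], results from it describe the combined test data instead of one corpus at a time. A caller that still wants a named, per-corpus breakdown can assemble the dictionary from the manifests directly; a hypothetical sketch:

    from pathlib import Path

    from lhotse import load_manifest_lazy

    manifest_dir = Path("data/fbank")  # layout assumed

    test_sets_cuts = {
        "reazonspeech_test": load_manifest_lazy(
            manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
        ),
        "mls_eng_test": load_manifest_lazy(manifest_dir / "mls_eng_cuts_test.jsonl.gz"),
    }
    for name, cuts in test_sets_cuts.items():
        # Decode each set separately to report per-corpus results.
        print(name, cuts)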
@@ -63,7 +63,7 @@ import k2
 import numpy as np
 import sentencepiece as spm
 import torch
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
 from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
@@ -740,7 +740,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     Tokenizer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -887,7 +887,7 @@ def main():
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)
 
     if params.bilingual:
         multi_dataset = MultiDataset(args)
@@ -904,8 +904,8 @@ def main():
         test_sets = test_sets_cuts.keys()
         test_cuts = [test_sets_cuts[k] for k in test_sets]
 
-    valid_cuts = reazonspeech_corpus.valid_cuts()
-    test_cuts = reazonspeech_corpus.test_cuts()
+    valid_cuts = multidataset_datamodule.valid_cuts()
+    test_cuts = multidataset_datamodule.test_cuts()
 
     test_sets = ["valid", "test"]
     test_cuts = [valid_cuts, test_cuts]
@@ -66,7 +66,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -272,7 +272,7 @@ def get_parser():
     parser.add_argument(
         "--bilingual",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether the model is bilingual or not. 1 = bilingual.",
     )
 
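The default for --bilingual flips from False to True here; because the option uses type=str2bool, it still accepts explicit values such as --bilingual 0 or --bilingual false on the command line. A sketch of a str2bool-style converter, written out for illustration (icefall ships its own in icefall.utils):

    import argparse


    def str2bool(v: str) -> bool:
        # Accept the common spellings of true/false; argparse's built-in bool()
        # conversion would get these wrong ("0" is a truthy string).
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


    parser = argparse.ArgumentParser()
    parser.add_argument("--bilingual", type=str2bool, default=True)
    print(parser.parse_args(["--bilingual", "0"]).bilingual)  # False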
@@ -804,7 +804,8 @@ def compute_loss(
 
     texts = batch["supervisions"]["text"]
     if not params.bilingual:
-        y = tokenizer.encode(texts, out_type=int)
+        assert NotImplementedError("only bilingual training has been implemented")
+        # y = tokenizer.encode(texts, out_type=int)
     else:
         y = sentencepiece_processor.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
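Note that assert NotImplementedError(...) never fails: the exception instance is truthy, so the assertion passes; raise NotImplementedError(...) is the form that actually stops non-bilingual runs. The bilingual branch itself is SentencePiece encoding followed by a ragged tensor; a minimal sketch with a hypothetical model path:

    import k2
    import sentencepiece as spm

    sentencepiece_processor = spm.SentencePieceProcessor()
    sentencepiece_processor.load("data/lang/bbpe.model")  # hypothetical path

    texts = ["今日はいい天気です", "the weather is nice today"]

    # encode() with out_type=int returns one list of token ids per utterance.
    y = sentencepiece_processor.encode(texts, out_type=int)

    # k2.RaggedTensor keeps the per-utterance lengths, which the transducer
    # loss needs because utterances have different numbers of tokens.
    y = k2.RaggedTensor(y)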
@@ -1147,9 +1148,10 @@ def run(rank, world_size, args):
     # <blk> is defined in local/prepare_lang_char.py
 
     if not params.bilingual:
-        tokenizer = Tokenizer.load(args.lang, args.lang_type)
-        params.blank_id = tokenizer.piece_to_id("<blk>")
-        params.vocab_size = tokenizer.get_piece_size()
+        assert NotImplementedError("only bilingual training has been implemented")
+        # tokenizer = Tokenizer.load(args.lang, args.lang_type)
+        # params.blank_id = tokenizer.piece_to_id("<blk>")
+        # params.vocab_size = tokenizer.get_piece_size()
     else:
         sentencepiece_processor = spm.SentencePieceProcessor()
         sentencepiece_processor.load(params.bpe_model)
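In the bilingual branch, the metadata that the commented-out Tokenizer lines used to supply (blank id and vocabulary size) would typically be read from the SentencePiece model instead; the hunk only shows the processor being loaded, so the follow-up below is an assumption:

    import sentencepiece as spm

    sentencepiece_processor = spm.SentencePieceProcessor()
    sentencepiece_processor.load("data/lang/bbpe.model")  # hypothetical path

    # Assumed follow-up, mirroring what the commented-out Tokenizer lines did:
    blank_id = sentencepiece_processor.piece_to_id("<blk>")
    vocab_size = sentencepiece_processor.get_piece_size()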
@@ -1212,12 +1214,13 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)
     if params.bilingual:
         multi_dataset = MultiDataset(args)
         train_cuts = multi_dataset.train_cuts()
     else:
-        train_cuts = reazonspeech_corpus.train_cuts()
+        assert NotImplementedError("only bilingual training has been implemented")
+        # train_cuts = reazonspeech_corpus.train_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1242,6 +1245,7 @@ def run(rank, world_size, args):
         # for subsampling
         T = ((c.num_samples - 7) // 2 + 1) // 2
         if not params.bilingual:
+            assert NotImplementedError("only bilingual training has been implemented")
             tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
         else:
             tokens = sentencepiece_processor.encode(
@@ -1272,6 +1276,8 @@ def run(rank, world_size, args):
 
     if params.bilingual:
         train_cuts = train_cuts.map(tokenize_and_encode_text)
+    else:
+        assert NotImplementedError("only bilingual training has been implemented")
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
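The context lines show the two checks applied to each training cut: a 1–20 second duration window and a cap on the number of tokens relative to the subsampled length T. A compact sketch of such a filter (not the recipe's exact function; sp stands in for a loaded SentencePiece processor):

    from lhotse.cut import Cut


    def remove_short_and_long_utt(c: Cut) -> bool:
        # Keep only utterances with duration between 1 second and 20 seconds.
        if c.duration < 1.0 or c.duration > 20.0:
            return False

        # The token sequence must fit into the subsampled frame sequence;
        # the formula is copied from the context lines above.
        T = ((c.num_samples - 7) // 2 + 1) // 2
        tokens = sp.encode(c.supervisions[0].text, out_type=str)
        return len(tokens) <= T


    # Applied lazily, exactly as in the diff:
    # train_cuts = train_cuts.filter(remove_short_and_long_utt)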
@@ -1280,15 +1286,20 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    train_dl = reazonspeech_corpus.train_dataloaders(
+    # train_dl = reazonspeech_corpus.train_dataloaders(
+    #     train_cuts, sampler_state_dict=sampler_state_dict
+    # )
+    train_dl = multidataset_datamodule.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
     if params.bilingual:
-        valid_cuts = reazonspeech_corpus.valid_cuts()
-    else:
         valid_cuts = multi_dataset.dev_cuts()
-    valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
+    else:
+        assert NotImplementedError("only bilingual training has been implemented")
+        # valid_cuts = multi_dataset.dev_cuts()
+
+    valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1386,7 +1397,8 @@ def display_and_save_batch(
     if params.bilingual:
         y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
     else:
-        y = tokenizer.encode(supervisions["text"], out_type=int)
+        assert NotImplementedError("only bilingual training has been implemented")
+        # y = tokenizer.encode(supervisions["text"], out_type=int)
     num_tokens = sum(len(i) for i in y)
     logging.info(f"num tokens: {num_tokens}")
 
@@ -1442,7 +1454,7 @@ def scan_pessimistic_batches_for_oom(
 
 def main():
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     Tokenizer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)