This commit is contained in:
JinZr 2023-07-21 01:54:35 +08:00
parent 748db76648
commit 2c8d65904c
2 changed files with 23 additions and 19 deletions

View File

@ -216,14 +216,14 @@ class MultiDataset:
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
)
return [
aidatatang_dev_cuts,
aishell_dev_cuts,
aishell2_dev_cuts,
alimeeting_dev_cuts,
magicdata_dev_cuts,
kespeech_dev_phase1_cuts,
kespeech_dev_phase2_cuts,
wenetspeech_dev_cuts,
]
return wenetspeech_dev_cuts
# return [
# aidatatang_dev_cuts,
# aishell_dev_cuts,
# aishell2_dev_cuts,
# alimeeting_dev_cuts,
# magicdata_dev_cuts,
# kespeech_dev_phase1_cuts,
# kespeech_dev_phase2_cuts,
# wenetspeech_dev_cuts,
# ]

View File

@ -72,6 +72,7 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
from multi_dataset import MultiDataset
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
@ -324,7 +325,7 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
default="data/lang_bpe_2000/bpe.model",
help="Path to the BPE model",
)
@ -1174,11 +1175,13 @@ def run(rank, world_size, args):
register_inf_check_hooks(model)
librispeech = LibriSpeechAsrDataModule(args)
multi_dataset = MultiDataset(args.manifest_dir)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = multi_dataset.train_cuts()
# train_cuts = librispeech.train_clean_100_cuts()
# if params.full_libri:
# train_cuts += librispeech.train_clean_360_cuts()
# train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 600 seconds
# (the upper bound was 20 seconds before this commit)
@ -1189,7 +1192,7 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
if c.duration < 1.0 or c.duration > 600.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
@ -1230,8 +1233,9 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
# valid_cuts = librispeech.dev_clean_cuts()
# valid_cuts += librispeech.dev_other_cuts()
valid_cuts = multi_dataset.dev_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics: