Merge branch 'multi_ja_en_mls_english_clean' into musan-mls-clean-final

commit 130c2a59c3 (committed by GitHub)
Author: Bailey Machiko Hirota
Date: 2025-08-06 11:45:20 +09:00
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in the database)
3 changed files with 11 additions and 14 deletions


@@ -40,7 +40,7 @@ log() {
 log "Starting MLS English data preparation"

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "Stage 0: Download data"
+  log "Stage 0: Download data"
   # Check if huggingface_hub is installed
   if ! python -c "import huggingface_hub" &> /dev/null; then
     log "huggingface_hub Python library not found. Installing it now..."


@@ -1219,7 +1219,6 @@ def run(rank, world_size, args):
     train_cuts = mls_english_corpus.train_cuts()
-    # mls_english_corpus.load_dataset(args.dataset_path)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
@@ -1241,7 +1240,6 @@ def run(rank, world_size, args):
     train_dl = mls_english_corpus.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )

     valid_dl = mls_english_corpus.valid_dataloader()

     if not params.print_diagnostics:
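Both hunks above belong to resuming from a checkpoint saved mid-epoch: the sampler's state dict is captured in the checkpoint and handed back to train_dataloaders so training continues from the same batch. A minimal sketch of the restore side, assuming a lhotse-style sampler that exposes load_state_dict() (the helper name is illustrative):

from typing import Any, Dict, Optional


def maybe_resume_sampler(train_dl, checkpoints: Optional[Dict[str, Any]]) -> None:
    # Rewind the dataloader's sampler when the checkpoint was taken
    # mid-epoch, so already-seen batches are not replayed.
    if checkpoints and "sampler" in checkpoints:
        train_dl.sampler.load_state_dict(checkpoints["sampler"])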


@@ -791,7 +791,7 @@ def compute_loss(
     warm_step = params.warm_step

     texts = batch["supervisions"]["text"]
-    y = sentencepiece_processor.encode(texts, out_type=int)
+    y = sentencepiece_processor.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
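In compute_loss the supervision texts are turned into BPE token IDs and wrapped in a ragged tensor, since each utterance yields a different number of tokens. A small sketch of that step, assuming a trained SentencePiece model at bpe.model (path illustrative):

import k2
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("bpe.model")  # illustrative path to a trained BPE model

texts = ["HELLO WORLD", "GOOD MORNING EVERYONE"]
ids = sp.encode(texts, out_type=int)  # one list of token IDs per utterance
y = k2.RaggedTensor(ids)              # rows may have different lengths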
@@ -1120,7 +1120,7 @@ def run(rank, world_size, args):
     # <blk> is defined in local/prepare_lang_char.py
     params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
-    arams.vocab_size = sentencepiece_processor.get_piece_size()
+    params.vocab_size = sentencepiece_processor.get_piece_size()

     if not params.use_transducer:
         params.ctc_loss_scale = 1.0
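The hunk above repairs a NameError (arams → params); both values come from the SentencePiece model itself rather than being hard-coded. A sketch of reading them off the processor (model path illustrative; <blk> must exist in the vocabulary):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("bpe.model")  # illustrative path

blank_id = sp.piece_to_id("<blk>")  # ID of the blank symbol
vocab_size = sp.get_piece_size()    # total number of pieces in the model
print(f"blank_id={blank_id}, vocab_size={vocab_size}")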
@@ -1178,22 +1178,20 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)
     multi_dataset = MultiDataset(args)

     train_cuts = multi_dataset.train_cuts()

     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 30 seconds
-        #
-        # Caution: There is a reason to select 30.0 here. Please see
-        # ../local/display_manifest_statistics.py
+        # Keep only utterances greater than 1 second
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
-        # the threshold
-        if c.duration < 1.0 or c.duration > 30.0:
+        # the threshold as this is dependent on which datasets you choose
+        if c.duration < 1.0:
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
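With several corpora mixed, a single 30-second cap no longer fits every dataset, so the filter now drops only utterances under 1 second and leaves the upper bound to per-dataset inspection with display_manifest_statistics.py. A sketch of applying such a predicate to a lhotse CutSet (manifest path illustrative):

import logging

from lhotse import CutSet
from lhotse.cut import Cut


def remove_short_utt(c: Cut) -> bool:
    # Keep only utterances longer than 1 second; log anything dropped.
    if c.duration < 1.0:
        logging.warning(
            f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
        )
        return False
    return True


cuts = CutSet.from_file("cuts_train.jsonl.gz")  # illustrative manifest
cuts = cuts.filter(remove_short_utt)  # lazy: evaluated while iterating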
@@ -1244,7 +1242,8 @@ def run(rank, world_size, args):
     )

     valid_cuts = multi_dataset.dev_cuts()
-    valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
+    valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1386,7 +1385,7 @@ def scan_pessimistic_batches_for_oom(
 def main():
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
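main() now registers the multi-dataset module's CLI flags instead of the ReazonSpeech-only ones. These datamodules follow an add_arguments classmethod pattern; the skeleton below is an illustrative sketch of that pattern, not the recipe's actual class (the --max-duration option and its default are assumptions):

import argparse


class MultiDatasetAsrDataModule:
    """Illustrative skeleton of the datamodule CLI pattern."""

    def __init__(self, args: argparse.Namespace) -> None:
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        # Each datamodule contributes its own option group to the parser.
        group = parser.add_argument_group(title="ASR data options")
        group.add_argument("--max-duration", type=float, default=200.0,
                           help="Max total speech duration (s) per batch.")


def main() -> None:
    parser = argparse.ArgumentParser()
    MultiDatasetAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    datamodule = MultiDatasetAsrDataModule(args)
    print(datamodule.args.max_duration)


if __name__ == "__main__":
    main()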