Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-10 22:45:27 +00:00)

Commit 130c2a59c3: Merge branch 'multi_ja_en_mls_english_clean' into musan-mls-clean-final
@@ -40,7 +40,7 @@ log() {
 log "Starting MLS English data preparation"

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Download data"
   # Check if huggingface_hub is installed
   if ! python -c "import huggingface_hub" &> /dev/null; then
     log "huggingface_hub Python library not found. Installing it now..."
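The shell check above can also live on the Python side of the preparation script. A minimal sketch of the same dependency-on-demand pattern, assuming pip in the current environment is the right installer (the helper name is ours, not from this commit):

# Sketch: install huggingface_hub on demand before downloading MLS English.
import importlib.util
import subprocess
import sys

def ensure_huggingface_hub() -> None:
    # Mirrors the shell test: `python -c "import huggingface_hub"`.
    if importlib.util.find_spec("huggingface_hub") is None:
        print("huggingface_hub Python library not found. Installing it now...")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "huggingface_hub"],
            check=True,
        )

ensure_huggingface_hub()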
@@ -1219,7 +1219,6 @@ def run(rank, world_size, args):
     train_cuts = mls_english_corpus.train_cuts()
     # mls_english_corpus.load_dataset(args.dataset_path)

-
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
@@ -1241,7 +1240,6 @@ def run(rank, world_size, args):
     train_dl = mls_english_corpus.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-
     valid_dl = mls_english_corpus.valid_dataloader()

     if not params.print_diagnostics:
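For context, the resume logic these hunks touch follows the usual icefall/lhotse pattern: a checkpoint saved mid-epoch carries the sampler's state, and passing it back into train_dataloaders lets the sampler skip batches it has already yielded. A sketch of the surrounding code, assuming `checkpoints` is the dict returned by icefall's checkpoint loading:

# Sketch: restore a lhotse sampler's position from a mid-epoch checkpoint.
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
    # Only restore the sampler when resuming mid-epoch; at an epoch
    # boundary a fresh sampler state is fine.
    sampler_state_dict = checkpoints["sampler"]
else:
    sampler_state_dict = None

train_dl = mls_english_corpus.train_dataloaders(
    train_cuts, sampler_state_dict=sampler_state_dict
)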
@@ -791,7 +791,7 @@ def compute_loss(
     warm_step = params.warm_step

     texts = batch["supervisions"]["text"]
     y = sentencepiece_processor.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
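The two encoding lines in this hunk turn a batch of transcripts into the ragged integer tensor that the transducer loss expects. A standalone sketch, assuming a trained SentencePiece model at the placeholder path bpe.model:

# Sketch: texts -> token IDs -> k2.RaggedTensor, as in compute_loss.
import k2
import sentencepiece as spm

sentencepiece_processor = spm.SentencePieceProcessor()
sentencepiece_processor.load("bpe.model")  # placeholder path

texts = ["hello world", "mls english"]
# encode() with out_type=int returns one list of token IDs per utterance;
# the lists have different lengths, hence a ragged tensor.
y = sentencepiece_processor.encode(texts, out_type=int)
y = k2.RaggedTensor(y)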
@@ -1120,7 +1120,7 @@ def run(rank, world_size, args):

     # <blk> is defined in local/prepare_lang_char.py
     params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
-    arams.vocab_size = sentencepiece_processor.get_piece_size()
+    params.vocab_size = sentencepiece_processor.get_piece_size()

     if not params.use_transducer:
         params.ctc_loss_scale = 1.0
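The fixed line reads model dimensions off the tokenizer instead of hard-coding them. A small self-contained sketch of the same pattern, assuming the <blk> piece was added when the SentencePiece model was built (as local/prepare_lang_char.py does):

# Sketch: derive blank ID and vocab size from the SentencePiece model.
from types import SimpleNamespace
import sentencepiece as spm

sentencepiece_processor = spm.SentencePieceProcessor()
sentencepiece_processor.load("bpe.model")  # placeholder path

params = SimpleNamespace()
params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
params.vocab_size = sentencepiece_processor.get_piece_size()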
@@ -1178,22 +1178,20 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)

     multi_dataset = MultiDataset(args)

     train_cuts = multi_dataset.train_cuts()

     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 30 seconds
-        #
-        # Caution: There is a reason to select 30.0 here. Please see
-        # ../local/display_manifest_statistics.py
+        # Keep only utterances greater than 1 second
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
-        # the threshold
-        if c.duration < 1.0 or c.duration > 30.0:
+        # the threshold as this is dependent on which datasets you choose
+        if c.duration < 1.0:
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
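Dropping the 30-second upper bound means only overly short utterances are excluded, matching the new comment that sensible thresholds depend on which datasets are mixed in. The predicate is applied lazily to the training CutSet; a sketch of the usual lhotse wiring around it:

# Sketch: apply the duration predicate from this hunk to a lhotse CutSet.
import logging
from lhotse import CutSet
from lhotse.cut import Cut

def remove_short_and_long_utt(c: Cut) -> bool:
    # Keep only utterances greater than 1 second; inspect your manifests
    # with ../local/display_manifest_statistics.py before choosing others.
    if c.duration < 1.0:
        logging.warning(
            f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
        )
        return False
    return True

def filter_cuts(train_cuts: CutSet) -> CutSet:
    # CutSet.filter is lazy: cuts are dropped as the sampler iterates.
    return train_cuts.filter(remove_short_and_long_utt)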
@@ -1244,7 +1242,8 @@ def run(rank, world_size, args):
         )

     valid_cuts = multi_dataset.dev_cuts()
-    valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
+
+    valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1386,7 +1385,7 @@ def scan_pessimistic_batches_for_oom(

 def main():
     parser = get_parser()
-    ReazonSpeechAsrDataModule.add_arguments(parser)
+    MultiDatasetAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

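MultiDatasetAsrDataModule.add_arguments follows the usual icefall convention: a classmethod that registers the data module's CLI options on the shared parser before parse_args runs. A minimal sketch of that convention; the class body and the option shown are illustrative, not taken from this commit:

# Sketch: the icefall data-module argument-registration pattern.
import argparse

class MultiDatasetAsrDataModule:
    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        group = parser.add_argument_group(title="ASR data related options")
        # Hypothetical example option; the real options live in the module.
        group.add_argument("--max-duration", type=int, default=200)

parser = argparse.ArgumentParser()
MultiDatasetAsrDataModule.add_arguments(parser)
args = parser.parse_args([])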