cleaned-up version of recipe

Kinan Martin 2025-04-15 10:19:51 +09:00
parent a4be3cb3db
commit cf8e9a8a1c
6 changed files with 204 additions and 385 deletions

View File

@@ -21,366 +21,230 @@ import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Dict, List, Optional
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from typing import Any, Dict, List, Optional, Union
+from lhotse import CutSet, Fbank, FbankConfig
 from lhotse.dataset import (
     CutConcatenate,
-    CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
-    PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
-from lhotse.utils import is_module_available
 from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 class MLSEnglishHFAsrDataModule:
     """
-    DataModule for k2 ASR experiments.
-    It assumes there is always one train and valid dataloader,
-    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
-    and test-other).
-    It contains all the common data pipeline modules used in ASR
-    experiments, e.g.:
-    - dynamic batch size,
-    - bucketing samplers,
-    - cut concatenation,
-    - augmentation,
-    - on-the-fly feature extraction
-    This class should be derived for specific corpora used in ASR tasks.
+    DataModule for MLS English ASR experiments using HuggingFace dataset.
+    Handles dataset loading and provides train/valid/test dataloaders with
+    on-the-fly feature extraction.
     """
     def __init__(self, args: argparse.Namespace):
         self.args = args
+        self.dataset = None
+        # self._validate_args()
+    # def _validate_args(self) -> None:
+    #     """Validate configuration arguments."""
+    #     if self.args.on_the_fly_feats is False:
+    #         raise ValueError("This recipe requires on-the-fly feature extraction")
     @classmethod
-    def add_arguments(cls, parser: argparse.ArgumentParser):
+    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description="Options for data loading and processing",
         )
+        # Dataset configuration
         group.add_argument(
-            "--manifest-dir",
-            type=Path,
-            default=Path("data/manifests"),
-            help="Path to directory with train/dev/test cuts.",
+            "--dataset-path",
+            type=str,
+            default="parler-tts/mls_eng",
+            help="Path to HuggingFace MLS English dataset (name or local path)",
         )
+        # Sampling and batching
         group.add_argument(
             "--max-duration",
-            type=int,
+            type=float,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help="Maximum batch duration in seconds",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help="Whether to use bucketing sampler",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help="Number of buckets for DynamicBucketingSampler",
         )
+        # Data augmentation
         group.add_argument(
-            "--concatenate-cuts",
-            type=str2bool,
-            default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
-        )
-        group.add_argument(
-            "--duration-factor",
-            type=float,
-            default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
-        )
-        group.add_argument(
-            "--gap",
-            type=float,
-            default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
-        )
-        group.add_argument(
-            "--on-the-fly-feats",
-            type=str2bool,
-            default=True, # must be true without lhotse feature prep
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
-        )
-        group.add_argument(
-            "--shuffle",
+            "--enable-spec-aug",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help="Whether to enable SpecAugment",
         )
         group.add_argument(
-            "--drop-last",
-            type=str2bool,
-            default=True,
-            help="Whether to drop last batch. Used by sampler.",
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help="Time warp factor for SpecAugment",
         )
+        # Dataloader configuration
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=2,
+            help="Number of workers for data loading",
+        )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help="Whether to return cuts in batch",
         )
         group.add_argument(
-            "--num-workers",
-            type=int,
-            default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
-        )
-        group.add_argument(
-            "--enable-spec-aug",
+            "--drop-last",
             type=str2bool,
             default=True,
-            help="When enabled, use SpecAugment for training dataset.",
+            help="Whether to drop last incomplete batch",
         )
-        group.add_argument(
-            "--spec-aug-time-warp-factor",
-            type=int,
-            default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
-        )
-        group.add_argument(
-            "--enable-musan",
-            type=str2bool,
-            default=False,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
-        )
+        return parser
-    def load_hf_dataset(
-        self, mls_eng_hf_dataset_path: str = "parler-tts/mls_eng",
-    ):
-        """
-        Method to load HF dataset with datasets.load_dataset
-        and save it in this DataModule.
-        Intended usage inside a training script:
-        ```
-        mls_english_corpus = MLSEnglishHFAsrDataModule(args)
-        mls_english_corpus.load_hf_dataset("parler-tts/mls_eng")
-        train_cuts = mls_english_corpus.train_cuts()
-        train_dataloader = mls_english_corpus.train_dataloaders(
-            train_cuts, sampler_state_dict=sampler_state_dict
-        )
-        ...
-        for epoch in range(...):
-            train_one_epoch(
-                ...,
-                train_dl=train_dl,
-                ...,
-            )
-        ```
-        """
-        if not is_module_available("datasets"):
-            raise ImportError(
-                "To process the MLS English HF corpus, please install optional dependency: pip install datasets"
-            )
-        from datasets import load_dataset
-        self.dataset = load_dataset(mls_eng_hf_dataset_path) #, split="test")
+    def load_dataset(self, dataset_path: Optional[str] = None) -> None:
+        """Load the HuggingFace dataset."""
+        dataset_path = dataset_path or self.args.dataset_path
+        logging.info(f"Loading MLS English dataset from: {dataset_path}")
+        try:
+            from datasets import load_dataset
+            self.dataset = load_dataset(dataset_path)
+            logging.info("Dataset loaded successfully")
+        except ImportError:
+            raise ImportError(
+                "Please install datasets package: pip install datasets"
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset: {e}")
-    def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
-    ) -> DataLoader:
-        """
-        Args:
-          cuts_train:
-            CutSet for training.
-          sampler_state_dict:
-            The state dict for the training sampler.
-        """
-        transforms = []
-        input_transforms = []
-        if self.args.enable_spec_aug:
-            logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
-            # Set the value of num_frame_masks according to Lhotse's version.
-            # In different Lhotse's versions, the default of num_frame_masks is
-            # different.
-            num_frame_masks = 10
-            num_frame_masks_parameter = inspect.signature(
-                SpecAugment.__init__
-            ).parameters["num_frame_masks"]
-            if num_frame_masks_parameter.default == 1:
-                num_frame_masks = 2
-            logging.info(f"Num frame mask: {num_frame_masks}")
-            input_transforms.append(
-                SpecAugment(
-                    time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=num_frame_masks,
-                    features_mask_size=27,
-                    num_feature_masks=2,
-                    frames_mask_size=100,
-                )
-            )
-        else:
-            logging.info("Disable SpecAugment")
-        logging.info("About to create train dataset")
-        train = K2SpeechRecognitionDataset(
-            cut_transforms=transforms,
-            input_transforms=input_transforms,
-            return_cuts=self.args.return_cuts,
-        )
+    def _create_dataset(self, cuts: CutSet, is_train: bool = False) -> K2SpeechRecognitionDataset:
+        """Create appropriate dataset with transforms."""
+        transforms = []
+        input_transforms = []
+        if is_train and self.args.enable_spec_aug:
+            input_transforms.append(self._create_spec_augment())
+        return K2SpeechRecognitionDataset(
+            cut_transforms=transforms,
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+            input_transforms=input_transforms,
+            return_cuts=self.args.return_cuts,
+        )
-        if self.args.on_the_fly_feats:
-            # NOTE: the PerturbSpeed transform should be added only if we
-            # remove it from data prep stage.
-            # Add on-the-fly speed perturbation; since originally it would
-            # have increased epoch size by 3, we will apply prob 2/3 and use
-            # 3x more epochs.
-            # Speed perturbation probably should come first before
-            # concatenation, but in principle the transforms order doesn't have
-            # to be strict (e.g. could be randomized)
-            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
-            # Drop feats to be on the safe side.
-            train = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
-                input_transforms=input_transforms,
-                return_cuts=self.args.return_cuts,
-            )
+    def _create_spec_augment(self) -> SpecAugment:
+        """Create SpecAugment transform based on config."""
+        num_frame_masks = 10
+        num_frame_masks_parameter = inspect.signature(
+            SpecAugment.__init__
+        ).parameters["num_frame_masks"]
+        if num_frame_masks_parameter.default == 1:
+            num_frame_masks = 2
+        return SpecAugment(
+            time_warp_factor=self.args.spec_aug_time_warp_factor,
+            num_frame_masks=num_frame_masks,
+            features_mask_size=27,
+            num_feature_masks=2,
+            frames_mask_size=100,
+        )
-        if self.args.bucketing_sampler:
-            logging.info("Using DynamicBucketingSampler.")
-            train_sampler = DynamicBucketingSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-                num_buckets=self.args.num_buckets,
-                drop_last=self.args.drop_last,
-            )
-        else:
-            logging.info("Using SimpleCutSampler.")
-            train_sampler = SimpleCutSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-            )
-        logging.info("About to create train dataloader")
-        if sampler_state_dict is not None:
-            logging.info("Loading sampler state dict")
-            train_sampler.load_state_dict(sampler_state_dict)
-        train_dl = DataLoader(
-            train,
-            sampler=train_sampler,
-            batch_size=None,
-            num_workers=self.args.num_workers,
-            persistent_workers=False,
-        )
-        return train_dl
+    def _create_sampler(self, cuts: CutSet, shuffle: bool) -> Union[DynamicBucketingSampler, SimpleCutSampler]:
+        """Create appropriate sampler based on config."""
+        if self.args.bucketing_sampler:
+            return DynamicBucketingSampler(
+                cuts,
+                max_duration=self.args.max_duration,
+                shuffle=shuffle,
+                num_buckets=self.args.num_buckets,
+                drop_last=self.args.drop_last,
+            )
+        return SimpleCutSampler(
+            cuts,
+            max_duration=self.args.max_duration,
+            shuffle=shuffle,
+        )
+    def train_dataloader(self, sampler_state_dict: Optional[Dict[str, Any]] = None) -> DataLoader:
+        """Create train dataloader."""
+        cuts = self.train_cuts()
+        dataset = self._create_dataset(cuts, is_train=True)
+        sampler = self._create_sampler(cuts, shuffle=True)
+        if sampler_state_dict:
+            sampler.load_state_dict(sampler_state_dict)
+        return DataLoader(
+            dataset,
+            sampler=sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+        )
-    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
-        transforms = []
-        if self.args.concatenate_cuts:
-            transforms = [
-                CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap
-                )
-            ] + transforms
-        logging.info("About to create dev dataset")
-        if self.args.on_the_fly_feats:
-            validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
-                return_cuts=self.args.return_cuts,
-            )
-        else:
-            validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                return_cuts=self.args.return_cuts,
-            )
-        valid_sampler = DynamicBucketingSampler(
-            cuts_valid,
-            max_duration=self.args.max_duration,
-            shuffle=False,
-        )
-        logging.info("About to create dev dataloader")
-        valid_dl = DataLoader(
-            validate,
-            sampler=valid_sampler,
-            batch_size=None,
-            num_workers=2,
-            persistent_workers=False,
-        )
-        return valid_dl
+    def valid_dataloader(self) -> DataLoader:
+        """Create validation dataloader."""
+        cuts = self.valid_cuts()
+        return DataLoader(
+            self._create_dataset(cuts),
+            sampler=self._create_sampler(cuts, shuffle=False),
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
-        logging.info("About to create test dataset")
-        test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
-            return_cuts=self.args.return_cuts,
-        )
-        sampler = DynamicBucketingSampler(
-            cuts,
-            max_duration=self.args.max_duration,
-            shuffle=False,
-        )
-        test_dl = DataLoader(
-            test,
-            batch_size=None,
-            sampler=sampler,
-            num_workers=self.args.num_workers,
-        )
-        return test_dl
+    def test_dataloader(self) -> DataLoader:
+        """Create test dataloader."""
+        cuts = self.test_cuts()
+        return DataLoader(
+            self._create_dataset(cuts),
+            sampler=self._create_sampler(cuts, shuffle=False),
+            batch_size=None,
+            num_workers=self.args.num_workers,
+        )
     @lru_cache()
     def train_cuts(self) -> CutSet:
-        logging.info("About to get train cuts")
-        cutset = CutSet.from_huggingface_dataset(self.dataset["train"], text_key="transcript")
-        return cutset
+        return CutSet.from_huggingface_dataset(
+            self.dataset["train"],
+            text_key="transcript"
+        )
     @lru_cache()
     def valid_cuts(self) -> CutSet:
-        logging.info("About to get dev cuts")
-        cutset = CutSet.from_huggingface_dataset(self.dataset["dev"], text_key="transcript")
-        return cutset
+        return CutSet.from_huggingface_dataset(
+            self.dataset["dev"],
+            text_key="transcript"
+        )
     @lru_cache()
-    def test_cuts(self) -> List[CutSet]:
-        logging.info("About to get test cuts")
-        cutset = CutSet.from_huggingface_dataset(self.dataset["test"], text_key="transcript")
-        return cutset
+    def test_cuts(self) -> CutSet:
+        return CutSet.from_huggingface_dataset(
+            self.dataset["test"],
+            text_key="transcript"
+        )
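
For reference, a minimal sketch of how the refactored datamodule is driven from a script, using only the methods shown above; the import path, the argument values, and the loop body are illustrative rather than part of this commit:

    import argparse

    from asr_datamodule import MLSEnglishHFAsrDataModule

    parser = argparse.ArgumentParser()
    MLSEnglishHFAsrDataModule.add_arguments(parser)
    args = parser.parse_args(["--dataset-path", "parler-tts/mls_eng", "--max-duration", "200"])

    corpus = MLSEnglishHFAsrDataModule(args)
    corpus.load_dataset()                 # falls back to args.dataset_path
    train_dl = corpus.train_dataloader()  # pass sampler_state_dict=... when resuming
    valid_dl = corpus.valid_dataloader()

    for batch in train_dl:
        feats = batch["inputs"]           # on-the-fly 80-dim fbank features
        supervisions = batch["supervisions"]
        break

Note that shuffling is now fixed per split (True for train, False for valid/test) rather than controlled by a --shuffle flag.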

View File

@@ -19,59 +19,71 @@
 import argparse
 import logging
 from pathlib import Path
+from typing import Optional
 from lhotse import CutSet
-from asr_datamodule import MLSEnglishHFAsrDataModule
 from tqdm import tqdm
 def get_args():
     parser = argparse.ArgumentParser(
+        description="Generate transcripts for BPE training from MLS English dataset",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
-    # parser.add_argument(
-    #     "train_cut", metavar="train-cut", type=Path, help="Path to the train cut"
-    # )
+    parser.add_argument(
+        "--dataset-path",
+        type=str,
+        default="parler-tts/mls_eng",
+        help="Path to HuggingFace MLS English dataset (name or local path)",
+    )
     parser.add_argument(
         "--lang-dir",
         type=Path,
         default=Path("data/lang"),
-        help=(
-            "Name of lang dir. "
-            "If not set, this will default to data/lang"
-        ),
+        help="Directory to store output transcripts",
+    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        default="train",
+        help="Dataset split to use for generating transcripts (train/dev/test)",
     )
     return parser.parse_args()
+def generate_transcript_from_cuts(cuts: CutSet, output_file: Path) -> None:
+    """Generate transcript text file from Lhotse CutSet."""
+    with open(output_file, "w") as f:
+        for cut in tqdm(cuts, desc="Processing cuts"):
+            for sup in cut.supervisions:
+                f.write(f"{sup.text}\n")
 def main():
     args = get_args()
     logging.basicConfig(
-        format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
+        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
         level=logging.INFO,
     )
-    mls_english_corpus = MLSEnglishHFAsrDataModule(args)
-    mls_english_corpus.load_hf_dataset("/root/datasets/parler-tts--mls_eng")
-    train_cuts = mls_english_corpus.train_cuts()
-    logging.info(f"Creating transcript from MLS English train cut.")
-    def generate_text(train_cuts):
-        for cut in tqdm(train_cuts):
-            for sup in cut.supervisions:
-                yield sup.text + "\n"
-    with open(args.lang_dir / "transcript.txt", "w") as file:
-        file.writelines(generate_text(train_cuts))
-    logging.info("Done.")
+    args.lang_dir.mkdir(parents=True, exist_ok=True)
+    output_file = args.lang_dir / "transcript.txt"
+    logging.info(f"Loading {args.split} split from dataset: {args.dataset_path}")
+    try:
+        cuts = CutSet.from_huggingface_dataset(
+            args.dataset_path,
+            split=args.split,
+            text_key="transcript"
+        )
+    except Exception as e:
+        logging.error(f"Failed to load dataset: {e}")
+        raise
+    logging.info(f"Generating transcript to {output_file}")
+    generate_transcript_from_cuts(cuts, output_file)
+    logging.info("Transcript generation completed")
 if __name__ == "__main__":
     main()
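
When invoked standalone, the rewritten script boils down to the following steps (a sketch that reuses the same Lhotse call shown above; the dev split and the output path are chosen here only for illustration):

    from pathlib import Path

    from lhotse import CutSet

    cuts = CutSet.from_huggingface_dataset(
        "parler-tts/mls_eng", split="dev", text_key="transcript"
    )
    out = Path("data/lang/transcript.txt")
    out.parent.mkdir(parents=True, exist_ok=True)
    with open(out, "w") as f:
        for cut in cuts:
            for sup in cut.supervisions:
                f.write(f"{sup.text}\n")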

View File

@@ -1,5 +1,9 @@
 #!/usr/bin/env bash
+# Prepare script for MLS English ASR recipe in icefall
+# This recipe uses on-the-fly feature extraction, so it skips manifest
+# and feature generation steps used in other recipes.

 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

@@ -9,118 +13,50 @@ nj=15
 stage=-1
 stop_stage=100
-# vocab_sizes=(500 1000 2000)
-vocab_sizes=(2000)
+# Configuration for BPE tokenizer
+vocab_sizes=(2000) # You can add more sizes like (500 1000 2000) for comparison
-# We assume dl_dir (download dir) contains the following
-# directories and files. If not, they will be downloaded
-# by this script automatically.
-#
-# - $dl_dir/ReazonSpeech
-#   You can find FLAC files in this directory.
-#   You can download them from https://huggingface.co/datasets/reazon-research/reazonspeech
-#
-# - $dl_dir/dataset.json
-#   The metadata of the ReazonSpeech dataset.
+# Directory where dataset will be downloaded
 dl_dir=$PWD/download

 . shared/parse_options.sh || exit 1

 # All files generated by this script are saved in "data".
-# You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data

 log() {
-  # This function is from espnet
   local fname=${BASH_SOURCE[1]##*/}
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }

-log "Running prepare.sh"
-log "dl_dir: $dl_dir"
+log "Starting MLS English data preparation"

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "Stage 0: Download data"
-  # If you have pre-downloaded it to /path/to/mls_eng,
-  # you can create a symlink
-  #
-  #   ln -sfv /path/to/mls_eng $dl_dir/mls_eng
-  #
+  log "Stage 0: Download MLS English dataset"
   if [ ! -d $dl_dir/mls_english ]; then
-    git clone https://huggingface.co/datasets/parler-tts/mls_eng $dl_dir/mls_eng
+    if ! git clone https://huggingface.co/datasets/parler-tts/mls_eng $dl_dir/mls_english; then
+      log "Failed to download MLS English dataset"
+      exit 1
+    fi
   fi
 fi
-## Not necessary to create manifest or pre-compute fbank for on-the-fly feature computation ##
-# if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-#   log "Stage 1: Prepare MLS English manifest"
-#   # We assume that you have downloaded the ReazonSpeech corpus
-#   # to $dl_dir/ReazonSpeech
-#   mkdir -p data/manifests
-#   if [ ! -e data/manifests/.reazonspeech.done ]; then
-#     lhotse prepare reazonspeech -j $nj $dl_dir/ReazonSpeech data/manifests
-#     touch data/manifests/.reazonspeech.done
-#   fi
-# fi
-# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-#   log "Stage 2: Compute ReazonSpeech fbank"
-#   if [ ! -e data/manifests/.reazonspeech-validated.done ]; then
-#     python local/compute_fbank_reazonspeech.py --manifest-dir data/manifests
-#     python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_train.jsonl.gz
-#     python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_dev.jsonl.gz
-#     python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_test.jsonl.gz
-#     touch data/manifests/.reazonspeech-validated.done
-#   fi
-# fi
-###############################################################################################
-# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-#   log "Stage 3: Prepare ReazonSpeech lang_char"
-#   python local/prepare_lang_char.py data/manifests/reazonspeech_cuts_train.jsonl.gz
-# fi
-# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-#   log "Stage 4: Show manifest statistics"
-#   python local/display_manifest_statistics.py --manifest-dir data/manifests > data/manifests/manifest_statistics.txt
-#   cat data/manifests/manifest_statistics.txt
-# fi

 mkdir -p data/lang
 lang_dir=data/lang
-log "lang_dir: $lang_dir"
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare BPE based lang"
+  log "Stage 1: Prepare BPE tokenizer"
   if [ ! -f $lang_dir/transcript.txt ]; then
-    log "Generate transcript for BPE training"
+    log "Generating transcripts for BPE training"
     ./local/utils/generate_transcript.py --lang-dir $lang_dir
-    # files=$(
-    #   find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
-    #   find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
-    #   find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
-    # )
-    # for f in ${files[@]}; do
-    #   cat $f | cut -d " " -f 2-
-    # done > $lang_dir/transcript_words.txt
   fi

   for vocab_size in ${vocab_sizes[@]}; do
-    log "Train BPE model with vocab_size: $vocab_size"
+    log "Training BPE model with vocab_size=${vocab_size}"
     bpe_dir=data/lang/bpe_${vocab_size}
     mkdir -p $bpe_dir
     if [ ! -f $bpe_dir/bpe.model ]; then
       ./local/train_bpe_model.py \
         --lang-dir $bpe_dir \
@@ -128,4 +64,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
         --transcript $lang_dir/transcript.txt
     fi
   done
 fi
+
+log "MLS English data preparation completed successfully"

View File

@@ -1 +1 @@
-local/utils/asr_datamodule.py
+../local/utils/asr_datamodule.py

View File

@@ -1043,13 +1043,13 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True
     mls_english_corpus = MLSEnglishHFAsrDataModule(args)
-    mls_english_corpus.load_hf_dataset("/root/datasets/parler-tts--mls_eng")
-    # dev_cuts = mls_english_corpus.dev_cuts()
-    test_cuts = mls_english_corpus.test_cuts()
-    # dev_dl = mls_english_corpus.test_dataloaders(dev_cuts)
-    test_dl = mls_english_corpus.test_dataloaders(test_cuts)
+    mls_english_corpus.load_dataset(args.dataset_path)
+    # # dev_cuts = mls_english_corpus.dev_cuts()
+    # test_cuts = mls_english_corpus.test_cuts()
+    # dev_dl = mls_english_corpus.test_dataloader()
+    test_dl = mls_english_corpus.test_dataloader()
     test_sets = ["test"]
     test_dls = [test_dl]

View File

@@ -1215,9 +1215,9 @@ def run(rank, world_size, args):
         return True
     mls_english_corpus = MLSEnglishHFAsrDataModule(args)
-    mls_english_corpus.load_hf_dataset("/root/datasets/parler-tts--mls_eng")
-    train_cuts = mls_english_corpus.train_cuts()
+    mls_english_corpus.load_dataset(args.dataset_path)
+    # train_cuts = mls_english_corpus.train_cuts()
     # train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1228,12 +1228,17 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
-    train_dl = mls_english_corpus.train_dataloaders(
-        train_cuts, sampler_state_dict=sampler_state_dict
-    )
-    valid_cuts = mls_english_corpus.valid_cuts()
-    valid_dl = mls_english_corpus.valid_dataloaders(valid_cuts)
+    # train_dl = mls_english_corpus.train_dataloaders(
+    #     train_cuts, sampler_state_dict=sampler_state_dict
+    # )
+    train_dl = mls_english_corpus.train_dataloader(
+        sampler_state_dict=sampler_state_dict
+    )
+    # valid_cuts = mls_english_corpus.valid_cuts()
+    # valid_dl = mls_english_corpus.valid_dataloader(valid_cuts)
+    valid_dl = mls_english_corpus.valid_dataloader()
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
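
The switch to train_dataloader(sampler_state_dict=...) means resuming only needs the sampler state stored in a checkpoint, not the original CutSet. A minimal sketch of that resume path follows; the checkpoint path and key name are illustrative, not icefall's exact code:

    import argparse

    import torch

    from asr_datamodule import MLSEnglishHFAsrDataModule

    parser = argparse.ArgumentParser()
    MLSEnglishHFAsrDataModule.add_arguments(parser)
    args = parser.parse_args([])  # real training passes the full experiment config

    corpus = MLSEnglishHFAsrDataModule(args)
    corpus.load_dataset(args.dataset_path)

    checkpoint = torch.load("exp/checkpoint-1234.pt", map_location="cpu")  # illustrative path
    sampler_state_dict = checkpoint.get("sampler")  # e.g. saved via train_dl.sampler.state_dict()

    train_dl = corpus.train_dataloader(sampler_state_dict=sampler_state_dict)
    valid_dl = corpus.valid_dataloader()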