diff --git a/egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech.py b/egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech2.py
similarity index 88%
rename from egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech.py
rename to egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech2.py
index 9e0df0989..57f6d5bdf 100755
--- a/egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech.py
+++ b/egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech2.py
@@ -30,7 +30,7 @@
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
-def compute_fbank_gigaspeech():
+def compute_fbank_gigaspeech2():
     in_out_dir = Path("data/fbank")
     # number of workers in dataloader
     num_workers = 20
@@ -38,7 +38,7 @@ def compute_fbank_gigaspeech():
     # number of seconds in a batch
     batch_duration = 1000
 
-    subsets = ("L", "M", "S", "XS", "DEV", "TEST")
+    subsets = ("test",)
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
@@ -48,12 +48,12 @@ def compute_fbank_gigaspeech():
     logging.info(f"device: {device}")
 
     for partition in subsets:
-        cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
+        cuts_path = in_out_dir / f"gigaspeech2_cuts_{partition}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue
 
-        raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
+        raw_cuts_path = in_out_dir / f"gigaspeech2_cuts_{partition}_raw.jsonl.gz"
         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
 
@@ -62,7 +62,7 @@ def compute_fbank_gigaspeech():
 
         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
+            storage_path=f"{in_out_dir}/gigaspeech2_feats_{partition}",
             num_workers=num_workers,
             batch_duration=batch_duration,
             overwrite=True,
@@ -80,7 +80,7 @@ def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
-    compute_fbank_gigaspeech()
+    compute_fbank_gigaspeech2()
 
 
 if __name__ == "__main__":
diff --git a/egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech2_splits.py
similarity index 81%
rename from egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech_splits.py
rename to egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech2_splits.py
index 51cd59078..4c76e7eea 100755
--- a/egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech_splits.py
+++ b/egs/gigaspeech2/SSL/local/compute_fbank_gigaspeech2_splits.py
@@ -37,12 +37,19 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        required=True,
+    )
+
     parser.add_argument(
         "--num-workers",
         type=int,
         default=20,
         help="Number of dataloading workers used for reading the audio.",
     )
+
     parser.add_argument(
         "--batch-duration",
         type=float,
@@ -55,7 +62,7 @@
         "--num-splits",
         type=int,
         required=True,
-        help="The number of splits of the XL subset",
+        help="The number of splits of the subset",
     )
 
     parser.add_argument(
@@ -71,12 +78,13 @@
         default=-1,
         help="Stop processing pieces until this number (exclusive).",
     )
+
 
     return parser
 
 
-def compute_fbank_gigaspeech_splits(args):
+def compute_fbank_gigaspeech2_splits(args):
     num_splits = args.num_splits
-    output_dir = f"data/fbank/XL_split"
+    output_dir = f"data/fbank/{args.dataset}_split"
     output_dir = Path(output_dir)
     assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -99,12 +107,14 @@ def compute_fbank_gigaspeech_splits(args):
         idx = f"{i}".zfill(num_digits)
         logging.info(f"Processing {idx}/{num_splits}")
 
-        cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
+        cuts_path = output_dir / f"gigaspeech2_cuts_{args.dataset}.{idx}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue
 
-        raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"
+        raw_cuts_path = (
+            output_dir / f"gigaspeech2_cuts_{args.dataset}_raw.{idx}.jsonl.gz"
+        )
         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
 
@@ -113,7 +123,7 @@
 
         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{output_dir}/gigaspeech_feats_{idx}",
+            storage_path=f"{output_dir}/gigaspeech2_feats_{idx}",
             num_workers=args.num_workers,
             batch_duration=args.batch_duration,
             overwrite=True,
@@ -130,30 +140,14 @@
 
 
 def main():
-    now = datetime.now()
-    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
-
-    log_filename = "log-compute_fbank_gigaspeech_splits"
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    log_filename = f"{log_filename}-{date_time}"
-
-    logging.basicConfig(
-        filename=log_filename,
-        format=formatter,
-        level=logging.INFO,
-        filemode="w",
-    )
-
-    console = logging.StreamHandler()
-    console.setLevel(logging.INFO)
-    console.setFormatter(logging.Formatter(formatter))
-    logging.getLogger("").addHandler(console)
+    logging.basicConfig(format=formatter, level=logging.INFO)
 
     parser = get_parser()
     args = parser.parse_args()
     logging.info(vars(args))
 
-    compute_fbank_gigaspeech_splits(args)
+    compute_fbank_gigaspeech2_splits(args)
 
 
 if __name__ == "__main__":
diff --git a/egs/gigaspeech2/SSL/local/preprocess_gigaspeech2.py b/egs/gigaspeech2/SSL/local/preprocess_gigaspeech2.py
index b7cd0c923..d63758006 100755
--- a/egs/gigaspeech2/SSL/local/preprocess_gigaspeech2.py
+++ b/egs/gigaspeech2/SSL/local/preprocess_gigaspeech2.py
@@ -32,6 +32,13 @@ def get_args():
     parser.add_argument(
         "--lang",
         type=str,
+        required=True,
+    )
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        required=True,
     )
 
     return parser.parse_args()
@@ -89,7 +96,7 @@ def preprocess_gigaspeech2(args):
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)
 
-    dataset_parts = ("test",)
+    dataset_parts = args.dataset.strip().split(" ", -1)
 
     logging.info("Loading manifest (may take 4 minutes)")
     manifests = read_manifests_if_cached(
diff --git a/egs/gigaspeech2/SSL/prepare.sh b/egs/gigaspeech2/SSL/prepare.sh
index 42590229d..bf4858bde 100755
--- a/egs/gigaspeech2/SSL/prepare.sh
+++ b/egs/gigaspeech2/SSL/prepare.sh
@@ -16,6 +16,7 @@
 stop_stage=5
 
 dl_dir=$PWD/download
 lang=Thai
+num_per_split=20000
 
 . shared/parse_options.sh || exit 1
@@ -33,6 +34,13 @@
 log "Running prepare.sh"
 log "dl_dir: $dl_dir"
 
+subsets=""
+for dir in ${dl_dir}/GigaSpeech2/* ; do
+  subset=$(basename $dir)
+  subsets="$subsets $subset"
+done
+log "Found subsets: $subsets"
+
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Prepare GigaSpeech2 manifest, language: $lang"
   # We assume that you have downloaded the GigaSpeech2 corpus
@@ -47,16 +55,43 @@ fi
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "Stage 2: Preprocess GigaSpeech2 manifest"
   if [ ! -f data/fbank/.preprocess.done ]; then
-    python3 ./local/preprocess_gigaspeech2.py --lang $lang
+    python3 ./local/preprocess_gigaspeech2.py --lang $lang --dataset "$subsets"
     touch data/fbank/.preprocess.done
   fi
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Compute fbank for gigaspeech2"
+  log "Stage 3: Compute fbank for test set"
   mkdir -p data/fbank
-  if [ ! -e data/fbank/.gigaspeech2.done ]; then
-    ./local/compute_fbank_gigaspeech2.py
-    touch data/fbank/.gigaspeech2.done
-  fi
+  ./local/compute_fbank_gigaspeech2.py
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Split train set into pieces"
+  for subset in $subsets; do
+    if [[ $subset != "test" ]]; then
+      log "Split subset: $subset"
+      split_dir=data/fbank/${subset}_split
+      if [ ! -f $split_dir/.split.done ]; then
+        lhotse split-lazy ./data/fbank/gigaspeech2_cuts_${subset}_raw.jsonl.gz $split_dir $num_per_split
+        touch $split_dir/.split.done
+      fi
+    fi
+  done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute features for train set"
+  for subset in $subsets; do
+    if [[ $subset != "test" ]]; then
+      log "Compute features for subset: $subset"
+      split_dir=data/fbank/${subset}_split
+      num_splits=$(find $split_dir -name "gigaspeech2_cuts_${subset}_raw.*.jsonl.gz" | wc -l)
+      python3 ./local/compute_fbank_gigaspeech2_splits.py \
+        --dataset $subset \
+        --num-workers 20 \
+        --batch-duration 1000 \
+        --num-splits $num_splits
+    fi
+  done
 fi
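
With these changes the whole pipeline is driven by prepare.sh. The invocation below is a rough sketch, not part of the patch: it assumes the usual icefall convention that shared/parse_options.sh maps --num-per-split onto the num_per_split variable above, and that the corpus has been downloaded to $dl_dir/GigaSpeech2 with one directory per subset, as the subset-discovery loop expects.

  # Sketch: run manifest preparation, preprocessing, test-set fbank,
  # train-set splitting, and per-split feature extraction in one go.
  ./prepare.sh --stage 1 --stop-stage 5 --lang Thai --num-per-split 20000

Stage 4 then splits each non-test subset into pieces of $num_per_split cuts with lhotse split-lazy, and stage 5 computes fbank features for each piece via compute_fbank_gigaspeech2_splits.py.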