mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
Combine XL splits lazily during training.
This commit is contained in:
parent
4e05213f87
commit
f0330f9d2d
51
egs/librispeech/ASR/local/test_load_XL_split.py
Executable file
51
egs/librispeech/ASR/local/test_load_XL_split.py
Executable file
@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file can be used to check if any split is corrupted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import re
|
||||||
|
|
||||||
|
import lhotse
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
d = "data/fbank/XL_split_2000"
|
||||||
|
filenames = list(glob.glob(f"{d}/cuts_XL.*.jsonl.gz"))
|
||||||
|
|
||||||
|
pattern = re.compile(r"cuts_XL.([0-9]+).jsonl.gz")
|
||||||
|
|
||||||
|
idx_filenames = [(int(pattern.search(c).group(1)), c) for c in filenames]
|
||||||
|
|
||||||
|
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
|
||||||
|
|
||||||
|
print(f"Loading {len(idx_filenames)} splits")
|
||||||
|
|
||||||
|
s = 0
|
||||||
|
for i, f in idx_filenames:
|
||||||
|
cuts = lhotse.load_manifest_lazy(f)
|
||||||
|
print(i, "filename", f)
|
||||||
|
for i, c in enumerate(cuts):
|
||||||
|
s += c.features.load().shape[0]
|
||||||
|
if i > 5:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -139,11 +139,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
--batch-duration 600 \
|
--batch-duration 600 \
|
||||||
--num-splits $num_splits
|
--num-splits $num_splits
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|
||||||
log "Stage 6: Combine features for XL"
|
|
||||||
if [ ! -f data/fbank/cuts_XL.jsonl.gz ]; then
|
|
||||||
pieces=$(find data/fbank/XL_split_${num_splits} -name "cuts_XL.*.jsonl.gz")
|
|
||||||
lhotse combine $pieces data/fbank/cuts_XL.jsonl.gz
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
@ -16,9 +16,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import glob
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import lhotse
|
||||||
from lhotse import CutSet, load_manifest
|
from lhotse import CutSet, load_manifest
|
||||||
|
|
||||||
|
|
||||||
@ -40,9 +42,14 @@ class GigaSpeech:
|
|||||||
self.manifest_dir = Path(manifest_dir)
|
self.manifest_dir = Path(manifest_dir)
|
||||||
|
|
||||||
def train_XL_cuts(self) -> CutSet:
|
def train_XL_cuts(self) -> CutSet:
|
||||||
f = self.manifest_dir / "cuts_XL_raw.jsonl.gz"
|
logging.info("About to get train-XL cuts")
|
||||||
logging.info(f"About to get train-XL cuts from {f}")
|
|
||||||
return CutSet.from_jsonl_lazy(f)
|
filenames = list(
|
||||||
|
glob.glob(f"{self.manifest_dir}/XL_split_2000/cuts_XL.*.jsonl.gz")
|
||||||
|
)
|
||||||
|
logging.info(f"Loading {len(filenames)} splits")
|
||||||
|
|
||||||
|
return lhotse.combine(lhotse.load_manifest_lazy(p) for p in filenames)
|
||||||
|
|
||||||
def train_L_cuts(self) -> CutSet:
|
def train_L_cuts(self) -> CutSet:
|
||||||
f = self.manifest_dir / "cuts_L_raw.jsonl.gz"
|
f = self.manifest_dir / "cuts_L_raw.jsonl.gz"
|
||||||
|
@ -986,7 +986,7 @@ def run(rank, world_size, args):
|
|||||||
giga_train_dl = asr_datamodule.train_dataloaders(
|
giga_train_dl = asr_datamodule.train_dataloaders(
|
||||||
train_giga_cuts,
|
train_giga_cuts,
|
||||||
dynamic_bucketing=True,
|
dynamic_bucketing=True,
|
||||||
on_the_fly_feats=True,
|
on_the_fly_feats=False,
|
||||||
cuts_musan=cuts_musan,
|
cuts_musan=cuts_musan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user