mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-10 17:44:20 +00:00

Use jsonl for cutsets in the librispeech recipe.

commit d93512344b
parent 8a3068ead8
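In short: cut manifests move from cuts_*.json.gz to librispeech_cuts_*.jsonl.gz (and musan_cuts.jsonl.gz), the feature writer changes from ChunkedLilcomHdf5Writer to LilcomChunkyWriter, and the CI cache keys are bumped so stale fbank caches get rebuilt. As a hedged sketch (not part of the diff; the path follows the new naming below), the renamed JSONL manifests load exactly like the old JSON ones:

    from lhotse import load_manifest

    # A jsonl.gz cutset loads the same way as the old json.gz one; the
    # one-cut-per-line format also allows lazy, streaming reads.
    cuts = load_manifest("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
    print(cuts)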
@@ -99,7 +99,7 @@ jobs:
         with:
           path: |
             ~/tmp/fbank-libri
-          key: cache-libri-fbank-test-clean-and-test-other
+          key: cache-libri-fbank-test-clean-and-test-other-v2
 
       - name: Compute fbank for LibriSpeech test-clean and test-other
         if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
@@ -99,7 +99,7 @@ jobs:
         with:
           path: |
             ~/tmp/fbank-libri
-          key: cache-libri-fbank-test-clean-and-test-other
+          key: cache-libri-fbank-test-clean-and-test-other-v2
 
       - name: Compute fbank for LibriSpeech test-clean and test-other
         if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
@@ -99,7 +99,7 @@ jobs:
         with:
           path: |
             ~/tmp/fbank-libri
-          key: cache-libri-fbank-test-clean-and-test-other
+          key: cache-libri-fbank-test-clean-and-test-other-v2
 
       - name: Compute fbank for LibriSpeech test-clean and test-other
         if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
@@ -99,7 +99,7 @@ jobs:
         with:
           path: |
             ~/tmp/fbank-libri
-          key: cache-libri-fbank-test-clean-and-test-other
+          key: cache-libri-fbank-test-clean-and-test-other-v2
 
       - name: Compute fbank for LibriSpeech test-clean and test-other
         if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
@@ -99,7 +99,7 @@ jobs:
         with:
           path: |
             ~/tmp/fbank-libri
-          key: cache-libri-fbank-test-clean-and-test-other
+          key: cache-libri-fbank-test-clean-and-test-other-v2
 
       - name: Compute fbank for LibriSpeech test-clean and test-other
         if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
@@ -98,7 +98,7 @@ jobs:
         with:
           path: |
             ~/tmp/fbank-libri
-          key: cache-libri-fbank-test-clean-and-test-other
+          key: cache-libri-fbank-test-clean-and-test-other-v2
 
       - name: Compute fbank for LibriSpeech test-clean and test-other
         if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
@@ -98,7 +98,7 @@ jobs:
         with:
           path: |
             ~/tmp/fbank-libri
-          key: cache-libri-fbank-test-clean-and-test-other
+          key: cache-libri-fbank-test-clean-and-test-other-v2
 
       - name: Compute fbank for LibriSpeech test-clean and test-other
         if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
@@ -98,7 +98,7 @@ jobs:
         with:
           path: |
             ~/tmp/fbank-libri
-          key: cache-libri-fbank-test-clean-and-test-other
+          key: cache-libri-fbank-test-clean-and-test-other-v2
 
       - name: Compute fbank for LibriSpeech test-clean and test-other
         if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
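The eight workflow hunks above differ only in the cache key suffix: bumping it to -v2 invalidates the previously cached ~/tmp/fbank-libri directories, which still hold manifests in the old json.gz layout.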
@@ -17,6 +17,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+Usage:
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./conformer_ctc/train.py \
+  --exp-dir ./conformer_ctc/exp \
+  --world-size 4 \
+  --full-libri 1 \
+  --max-duration 200 \
+  --num-epochs 20
+"""
+
 import argparse
 import logging
 from pathlib import Path
@@ -29,6 +40,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -676,6 +688,18 @@ def run(rank, world_size, args):
     if params.full_libri:
         train_cuts += librispeech.train_clean_360_cuts()
         train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
 
     train_dl = librispeech.train_dataloaders(train_cuts)
 
     valid_cuts = librispeech.dev_clean_cuts()
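The comments in remove_short_and_long_utt point at ../local/display_manifest_statistics.py for picking the 1.0/20.0 thresholds. A hedged sketch of inspecting the duration distribution directly (CutSet.describe() is a Lhotse convenience that prints duration statistics; the manifest path is assumed from this commit's naming):

    from lhotse import load_manifest

    cuts = load_manifest("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
    cuts.describe()  # prints cut counts, total duration, and duration percentiles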
@@ -28,7 +28,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -56,8 +56,13 @@ def compute_fbank_librispeech():
         "train-clean-360",
         "train-other-500",
     )
+    prefix = "librispeech"
+    suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
-        prefix="librispeech", dataset_parts=dataset_parts, output_dir=src_dir
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
     )
     assert manifests is not None
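With explicit prefix and suffix, read_manifests_if_cached looks for per-part recording and supervision manifests under the new names (e.g. librispeech_recordings_train-clean-100.jsonl.gz; the exact on-disk names follow Lhotse's convention and are an assumption here). A minimal sketch, assuming data/manifests was filled by the lhotse preparation stage:

    from pathlib import Path

    from lhotse.recipes.utils import read_manifests_if_cached

    manifests = read_manifests_if_cached(
        dataset_parts=("train-clean-100",),
        output_dir=Path("data/manifests"),
        prefix="librispeech",
        suffix="jsonl.gz",
    )
    # Each part maps to its recording and supervision manifests.
    part = manifests["train-clean-100"]
    recordings, supervisions = part["recordings"], part["supervisions"]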
@@ -65,7 +70,8 @@ def compute_fbank_librispeech():
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
-            if (output_dir / f"cuts_{partition}.json.gz").is_file():
+            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+            if (output_dir / cuts_filename).is_file():
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
@@ -81,13 +87,13 @@ def compute_fbank_librispeech():
             )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
-                storage_path=f"{output_dir}/feats_{partition}",
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=ChunkedLilcomHdf5Writer,
+                storage_type=LilcomChunkyWriter,
             )
-            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+            cut_set.to_file(output_dir / cuts_filename)
 
 
 if __name__ == "__main__":
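After the script runs, each jsonl cutset references the features written by LilcomChunkyWriter, so features can be loaded back through the manifest alone. A hedged sketch (the test-clean path is assumed from the naming above; the feature dimension depends on the FbankConfig used by the recipe, which this diff does not show):

    from lhotse import load_manifest

    cuts = load_manifest("data/fbank/librispeech_cuts_test-clean.jsonl.gz")
    cut = next(iter(cuts))
    feats = cut.load_features()  # numpy array of shape (num_frames, num_mel_bins)
    print(feats.shape)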
@@ -28,7 +28,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig, combine
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -52,12 +52,22 @@ def compute_fbank_musan():
         "speech",
         "noise",
     )
+    prefix = "musan"
+    suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
-        prefix="musan", dataset_parts=dataset_parts, output_dir=src_dir
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
     )
     assert manifests is not None
 
-    musan_cuts_path = output_dir / "cuts_musan.json.gz"
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+    )
+
+    musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
 
     if musan_cuts_path.is_file():
         logging.info(f"{musan_cuts_path} already exists - skipping")
@@ -79,13 +89,13 @@ def compute_fbank_musan():
         .filter(lambda c: c.duration > 5)
         .compute_and_store_features(
             extractor=extractor,
-            storage_path=f"{output_dir}/feats_musan",
+            storage_path=f"{output_dir}/musan_feats",
             num_jobs=num_jobs if ex is None else 80,
             executor=ex,
-            storage_type=ChunkedLilcomHdf5Writer,
+            storage_type=LilcomChunkyWriter,
         )
     )
-    musan_cuts.to_json(musan_cuts_path)
+    musan_cuts.to_file(musan_cuts_path)
 
 
 if __name__ == "__main__":
@@ -25,7 +25,7 @@ We will add more checks later if needed.
 Usage example:
 
     python3 ./local/validate_manifest.py \
-        ./data/fbank/cuts_train-clean-100.json.gz
+        ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz
 
 """
 
@@ -33,7 +33,7 @@ import argparse
 import logging
 from pathlib import Path
 
-from lhotse import load_manifest, CutSet
+from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 
 
@@ -40,9 +40,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  5000
-  2000
-  1000
+  # 5000
+  # 2000
+  # 1000
   500
 )
 
@@ -132,7 +132,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   )
   for part in ${parts[@]}; do
     python3 ./local/validate_manifest.py \
-      data/fbank/cuts_${part}.json.gz
+      data/fbank/librispeech_cuts_${part}.jsonl.gz
   done
   touch data/fbank/.librispeech-validated.done
 fi
@@ -807,28 +807,8 @@ def run(rank, world_size, args):
         # the threshold
         return 1.0 <= c.duration <= 20.0
 
-    num_in_total = len(train_cuts)
-
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
-    try:
-        num_left = len(train_cuts)
-        num_removed = num_in_total - num_left
-        removed_percent = num_removed / num_in_total * 100
-
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
-        logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
-    except TypeError as e:
-        # You can ignore this error as previous versions of Lhotse work fine
-        # for the above code. In recent versions of Lhotse, it uses
-        # lazy filter, producing cutsets that don't have the __len__ method
-        logging.info(str(e))
-
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
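The deleted block counted cuts with len() before and after filtering; as its own comment notes, recent Lhotse makes CutSet.filter() lazy, and a lazy cutset has no __len__. A small illustration of the failure mode and a safe alternative (manifest path assumed from this commit's naming):

    from lhotse import load_manifest

    cuts = load_manifest("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
    kept = cuts.filter(lambda c: 1.0 <= c.duration <= 20.0)
    # len(kept) may raise TypeError on a lazy CutSet; counting the survivors
    # requires a full pass over the data instead:
    num_kept = sum(1 for _ in kept)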
@@ -225,7 +225,7 @@ class LibriSpeechAsrDataModule:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
             cuts_musan = load_manifest(
-                self.args.manifest_dir / "cuts_musan.json.gz"
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
             )
             transforms.append(
                 CutMix(
@@ -408,39 +408,47 @@ class LibriSpeechAsrDataModule:
     def train_clean_100_cuts(self) -> CutSet:
         logging.info("About to get train-clean-100 cuts")
         return load_manifest(
-            self.args.manifest_dir / "cuts_train-clean-100.json.gz"
+            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
         )
 
     @lru_cache()
     def train_clean_360_cuts(self) -> CutSet:
         logging.info("About to get train-clean-360 cuts")
         return load_manifest(
-            self.args.manifest_dir / "cuts_train-clean-360.json.gz"
+            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
         )
 
     @lru_cache()
     def train_other_500_cuts(self) -> CutSet:
         logging.info("About to get train-other-500 cuts")
         return load_manifest(
-            self.args.manifest_dir / "cuts_train-other-500.json.gz"
+            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
         )
 
     @lru_cache()
     def dev_clean_cuts(self) -> CutSet:
         logging.info("About to get dev-clean cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz")
+        return load_manifest(
+            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+        )
 
     @lru_cache()
     def dev_other_cuts(self) -> CutSet:
         logging.info("About to get dev-other cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz")
+        return load_manifest(
+            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+        )
 
     @lru_cache()
     def test_clean_cuts(self) -> CutSet:
         logging.info("About to get test-clean cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz")
+        return load_manifest(
+            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+        )
 
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz")
+        return load_manifest(
+            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+        )
@@ -16,6 +16,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+Usage:
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./tdnn_lstm_ctc/train.py \
+  --world-size 4 \
+  --full-libri 1 \
+  --max-duration 300 \
+  --num-epochs 20
+"""
 
 import argparse
 import logging
@@ -29,6 +38,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
 from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from model import TdnnLstm
 from torch import Tensor
@@ -544,10 +554,25 @@ def run(rank, world_size, args):
     if params.full_libri:
         train_cuts += librispeech.train_clean_360_cuts()
         train_cuts += librispeech.train_other_500_cuts()
 
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
     train_dl = librispeech.train_dataloaders(train_cuts)
 
     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
 
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     for epoch in range(params.start_epoch, params.num_epochs):