mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
Make some changes and add text normalization
This commit is contained in:
parent
4215ec434a
commit
ddc55423b1
@ -29,7 +29,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
@ -53,11 +53,13 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
|
|||||||
"train_L",
|
"train_L",
|
||||||
"test",
|
"test",
|
||||||
)
|
)
|
||||||
|
prefix = "aishell4"
|
||||||
|
suffix = "jsonl.gz"
|
||||||
manifests = read_manifests_if_cached(
|
manifests = read_manifests_if_cached(
|
||||||
dataset_parts=dataset_parts,
|
dataset_parts=dataset_parts,
|
||||||
output_dir=src_dir,
|
output_dir=src_dir,
|
||||||
prefix="aishell4",
|
prefix=prefix,
|
||||||
suffix="jsonl.gz",
|
suffix=suffix,
|
||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
@ -65,7 +67,8 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
|
|||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
if (output_dir / f"cuts_{partition}.jsonl").is_file():
|
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||||
|
if (output_dir / cuts_filename).is_file():
|
||||||
logging.info(f"{partition} already exists - skipping.")
|
logging.info(f"{partition} already exists - skipping.")
|
||||||
continue
|
continue
|
||||||
logging.info(f"Processing {partition}")
|
logging.info(f"Processing {partition}")
|
||||||
@ -81,11 +84,11 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
|
|||||||
)
|
)
|
||||||
cut_set = cut_set.compute_and_store_features(
|
cut_set = cut_set.compute_and_store_features(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=f"{output_dir}/feats_{partition}",
|
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||||
# when an executor is specified, make more partitions
|
# when an executor is specified, make more partitions
|
||||||
num_jobs=num_jobs if ex is None else 80,
|
num_jobs=num_jobs if ex is None else 80,
|
||||||
executor=ex,
|
executor=ex,
|
||||||
storage_type=ChunkedLilcomHdf5Writer,
|
storage_type=LilcomChunkyWriter,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("About splitting cuts into smaller chunks")
|
logging.info("About splitting cuts into smaller chunks")
|
||||||
@ -94,7 +97,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
|
|||||||
min_duration=None,
|
min_duration=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
|
cut_set.to_file(output_dir / cuts_filename)
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
|
@ -48,7 +48,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
|||||||
if [ ! -f $dl_dir/aishell4/train_L ]; then
|
if [ ! -f $dl_dir/aishell4/train_L ]; then
|
||||||
lhotse download aishell4 $dl_dir/aishell4
|
lhotse download aishell4 $dl_dir/aishell4
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# If you have pre-downloaded it to /path/to/musan,
|
# If you have pre-downloaded it to /path/to/musan,
|
||||||
# you can create a symlink
|
# you can create a symlink
|
||||||
#
|
#
|
||||||
@ -117,9 +117,26 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|||||||
# Prepare text.
|
# Prepare text.
|
||||||
# Note: in Linux, you can install jq with the following command:
|
# Note: in Linux, you can install jq with the following command:
|
||||||
# wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
|
# wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
|
||||||
gunzip -c data/manifests/aishell4/supervisions_train_L.jsonl.gz \
|
gunzip -c data/manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \
|
||||||
| jq ".text" | sed 's/"//g' | sed 's/<sil>//g' \
|
| jq ".text" | sed 's/"//g' \
|
||||||
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
| ./local/text2token.py -t "char" > $lang_char_dir/text_S
|
||||||
|
|
||||||
|
gunzip -c data/manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \
|
||||||
|
| jq ".text" | sed 's/"//g' \
|
||||||
|
| ./local/text2token.py -t "char" > $lang_char_dir/text_M
|
||||||
|
|
||||||
|
gunzip -c data/manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \
|
||||||
|
| jq ".text" | sed 's/"//g' \
|
||||||
|
| ./local/text2token.py -t "char" > $lang_char_dir/text_L
|
||||||
|
|
||||||
|
for r in text_S text_M text_L ; do
|
||||||
|
cat $lang_char_dir/$r >> $lang_char_dir/text_full
|
||||||
|
done
|
||||||
|
|
||||||
|
# Prepare text normalize
|
||||||
|
python ./local/text_normalize.py \
|
||||||
|
--input $lang_char_dir/text_full \
|
||||||
|
--output $lang_char_dir/text
|
||||||
|
|
||||||
# Prepare words segments
|
# Prepare words segments
|
||||||
python ./local/text2segments.py \
|
python ./local/text2segments.py \
|
||||||
|
@ -23,14 +23,8 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import (
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
|
||||||
CutSet,
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||||
Fbank,
|
|
||||||
FbankConfig,
|
|
||||||
load_manifest,
|
|
||||||
set_caching_enabled,
|
|
||||||
)
|
|
||||||
from lhotse.dataset import (
|
|
||||||
CutConcatenate,
|
CutConcatenate,
|
||||||
CutMix,
|
CutMix,
|
||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
@ -39,15 +33,15 @@ from lhotse.dataset import (
|
|||||||
SingleCutSampler,
|
SingleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples
|
||||||
|
AudioSamples,
|
||||||
|
OnTheFlyFeatures,
|
||||||
|
)
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
set_caching_enabled(False)
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
|
|
||||||
|
|
||||||
class _SeedWorkers:
|
class _SeedWorkers:
|
||||||
def __init__(self, seed: int):
|
def __init__(self, seed: int):
|
||||||
@ -85,12 +79,14 @@ class Aishell4AsrDataModule:
|
|||||||
"effective batch sizes, sampling strategies, applied data "
|
"effective batch sizes, sampling strategies, applied data "
|
||||||
"augmentations, etc.",
|
"augmentations, etc.",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--manifest-dir",
|
"--manifest-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=Path("data/fbank"),
|
default=Path("data/fbank"),
|
||||||
help="Path to directory with train/valid/test cuts.",
|
help="Path to directory with train/valid/test cuts.",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--max-duration",
|
"--max-duration",
|
||||||
type=int,
|
type=int,
|
||||||
@ -98,6 +94,7 @@ class Aishell4AsrDataModule:
|
|||||||
help="Maximum pooled recordings duration (seconds) in a "
|
help="Maximum pooled recordings duration (seconds) in a "
|
||||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--bucketing-sampler",
|
"--bucketing-sampler",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -105,6 +102,7 @@ class Aishell4AsrDataModule:
|
|||||||
help="When enabled, the batches will come from buckets of "
|
help="When enabled, the batches will come from buckets of "
|
||||||
"similar duration (saves padding frames).",
|
"similar duration (saves padding frames).",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--num-buckets",
|
"--num-buckets",
|
||||||
type=int,
|
type=int,
|
||||||
@ -112,6 +110,7 @@ class Aishell4AsrDataModule:
|
|||||||
help="The number of buckets for the DynamicBucketingSampler"
|
help="The number of buckets for the DynamicBucketingSampler"
|
||||||
"(you might want to increase it for larger datasets).",
|
"(you might want to increase it for larger datasets).",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--concatenate-cuts",
|
"--concatenate-cuts",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -119,6 +118,7 @@ class Aishell4AsrDataModule:
|
|||||||
help="When enabled, utterances (cuts) will be concatenated "
|
help="When enabled, utterances (cuts) will be concatenated "
|
||||||
"to minimize the amount of padding.",
|
"to minimize the amount of padding.",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--duration-factor",
|
"--duration-factor",
|
||||||
type=float,
|
type=float,
|
||||||
@ -126,6 +126,7 @@ class Aishell4AsrDataModule:
|
|||||||
help="Determines the maximum duration of a concatenated cut "
|
help="Determines the maximum duration of a concatenated cut "
|
||||||
"relative to the duration of the longest cut in a batch.",
|
"relative to the duration of the longest cut in a batch.",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--gap",
|
"--gap",
|
||||||
type=float,
|
type=float,
|
||||||
@ -134,6 +135,7 @@ class Aishell4AsrDataModule:
|
|||||||
"concatenated cuts. This padding is filled with noise when "
|
"concatenated cuts. This padding is filled with noise when "
|
||||||
"noise augmentation is used.",
|
"noise augmentation is used.",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--on-the-fly-feats",
|
"--on-the-fly-feats",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -142,6 +144,7 @@ class Aishell4AsrDataModule:
|
|||||||
"extraction. Will drop existing precomputed feature manifests "
|
"extraction. Will drop existing precomputed feature manifests "
|
||||||
"if available.",
|
"if available.",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--shuffle",
|
"--shuffle",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -149,6 +152,14 @@ class Aishell4AsrDataModule:
|
|||||||
help="When enabled (=default), the examples will be "
|
help="When enabled (=default), the examples will be "
|
||||||
"shuffled for each epoch.",
|
"shuffled for each epoch.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--drop-last",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to drop last batch. Used by sampler.",
|
||||||
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--return-cuts",
|
"--return-cuts",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -192,10 +203,10 @@ class Aishell4AsrDataModule:
|
|||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--lazy-load",
|
"--input-strategy",
|
||||||
type=str2bool,
|
type=str,
|
||||||
default=True,
|
default="PrecomputedFeatures",
|
||||||
help="lazily open CutSets to avoid OOM (for L|XL subset)",
|
help="AudioSamples or PrecomputedFeatures",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
@ -218,8 +229,8 @@ class Aishell4AsrDataModule:
|
|||||||
The state dict for the training sampler.
|
The state dict for the training sampler.
|
||||||
"""
|
"""
|
||||||
logging.info("About to get Musan cuts")
|
logging.info("About to get Musan cuts")
|
||||||
cuts_musan = load_manifest(
|
cuts_musan = load_manifest_lazy(
|
||||||
self.args.manifest_dir / "cuts_musan.json.gz"
|
self.args.manifest_dir / "cuts_musan.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
transforms = []
|
transforms = []
|
||||||
@ -277,6 +288,7 @@ class Aishell4AsrDataModule:
|
|||||||
|
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
@ -310,7 +322,7 @@ class Aishell4AsrDataModule:
|
|||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
buffer_size=30000,
|
buffer_size=30000,
|
||||||
drop_last=True,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SingleCutSampler.")
|
||||||
@ -367,8 +379,6 @@ class Aishell4AsrDataModule:
|
|||||||
valid_sampler = DynamicBucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
rank=0,
|
|
||||||
world_size=1,
|
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
@ -393,14 +403,12 @@ class Aishell4AsrDataModule:
|
|||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
if self.args.on_the_fly_feats
|
if self.args.on_the_fly_feats
|
||||||
else PrecomputedFeatures(),
|
else eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
cuts,
|
cuts,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
rank=0,
|
|
||||||
world_size=1,
|
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
|
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
|
||||||
@ -419,26 +427,22 @@ class Aishell4AsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def train_cuts(self) -> CutSet:
|
||||||
logging.info("About to get train cuts")
|
logging.info("About to get train cuts")
|
||||||
if self.args.lazy_load:
|
return load_manifest_lazy(
|
||||||
logging.info("use lazy cuts")
|
self.args.manifest_dir
|
||||||
cuts_train = CutSet.from_jsonl_lazy(
|
/ f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
|
||||||
self.args.manifest_dir
|
)
|
||||||
/ f"cuts_train_{self.args.training_subset}.json.gz"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cuts_train = CutSet.from_file(
|
|
||||||
self.args.manifest_dir
|
|
||||||
/ f"cuts_train_{self.args.training_subset}.json.gz"
|
|
||||||
)
|
|
||||||
return cuts_train
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def valid_cuts(self) -> CutSet:
|
def valid_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev cuts")
|
logging.info("About to get dev cuts")
|
||||||
# Aishell4 doesn't have dev data, here use test to replace dev.
|
# Aishell4 doesn't have dev data, here use test to replace dev.
|
||||||
return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_cuts(self) -> List[CutSet]:
|
def test_cuts(self) -> List[CutSet]:
|
||||||
logging.info("About to get test cuts")
|
logging.info("About to get test cuts")
|
||||||
return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user