do some changes and text normalize

luomingshuang 2022-06-07 12:16:51 +08:00
parent 4215ec434a
commit ddc55423b1
3 changed files with 73 additions and 49 deletions

local/compute_fbank_aishell4.py

@@ -29,7 +29,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -53,11 +53,13 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
         "train_L",
         "test",
     )
+    prefix = "aishell4"
+    suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
         dataset_parts=dataset_parts,
         output_dir=src_dir,
-        prefix="aishell4",
-        suffix="jsonl.gz",
+        prefix=prefix,
+        suffix=suffix,
     )
     assert manifests is not None
 
@@ -65,7 +67,8 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
-            if (output_dir / f"cuts_{partition}.jsonl").is_file():
+            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+            if (output_dir / cuts_filename).is_file():
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
@@ -81,11 +84,11 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
             )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
-                storage_path=f"{output_dir}/feats_{partition}",
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=ChunkedLilcomHdf5Writer,
+                storage_type=LilcomChunkyWriter,
             )
 
             logging.info("About splitting cuts into smaller chunks")
@@ -94,7 +97,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
                 min_duration=None,
             )
-            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+            cut_set.to_json(output_dir / cuts_filename)
 
 
 def get_args():
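With this change the cut manifests are written as aishell4_cuts_{partition}.jsonl.gz instead of cuts_{partition}.json.gz, so any downstream script has to look for the new names. A minimal sketch of how a consumer could locate and lazily open the renamed manifests (assuming they are stored in the JSONL format the new suffix implies; data/fbank is this recipe's default output directory, and the partition list mirrors the one in this script):

from pathlib import Path

from lhotse import load_manifest_lazy

# Default feature directory of this recipe (hypothetical caller).
output_dir = Path("data/fbank")
prefix = "aishell4"
suffix = "jsonl.gz"

for partition in ("train_S", "train_M", "train_L", "test"):
    cuts_filename = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
    if not cuts_filename.is_file():
        print(f"{cuts_filename} not found - run compute_fbank_aishell4.py first")
        continue
    # load_manifest_lazy reads the .jsonl.gz incrementally, so even the
    # large train_L manifest need not fit in memory at once.
    cuts = load_manifest_lazy(cuts_filename)
    print(partition, cuts)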

prepare.sh

@@ -48,7 +48,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   if [ ! -f $dl_dir/aishell4/train_L ]; then
     lhotse download aishell4 $dl_dir/aishell4
   fi
 
   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
@@ -117,9 +117,26 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   # Prepare text.
   # Note: in Linux, you can install jq with the following command:
   # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
-  gunzip -c data/manifests/aishell4/supervisions_train_L.jsonl.gz \
-    | jq ".text" | sed 's/"//g' | sed 's/<sil>//g' \
-    | ./local/text2token.py -t "char" > $lang_char_dir/text
+  gunzip -c data/manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \
+    | jq ".text" | sed 's/"//g' \
+    | ./local/text2token.py -t "char" > $lang_char_dir/text_S
+
+  gunzip -c data/manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \
+    | jq ".text" | sed 's/"//g' \
+    | ./local/text2token.py -t "char" > $lang_char_dir/text_M
+
+  gunzip -c data/manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \
+    | jq ".text" | sed 's/"//g' \
+    | ./local/text2token.py -t "char" > $lang_char_dir/text_L
+
+  for r in text_S text_M text_L; do
+    cat $lang_char_dir/$r >> $lang_char_dir/text_full
+  done
+
+  # Normalize the text.
+  python ./local/text_normalize.py \
+    --input $lang_char_dir/text_full \
+    --output $lang_char_dir/text
 
   # Prepare words segments
   python ./local/text2segments.py \
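For reference, the three gunzip | jq | sed pipelines above only pull the "text" field out of each supervision manifest and strip the quotes; the char tokenization (text2token.py) and the normalization (text_normalize.py) are separate steps. Below is a rough Python equivalent of just the extract-and-concatenate part, reading the manifests with lhotse instead of gunzip/jq. This is a sketch, not part of the commit; it assumes $lang_char_dir points at data/lang_char as in other icefall recipes, and it omits the tokenization and normalization steps:

from pathlib import Path

from lhotse import load_manifest_lazy

manifest_dir = Path("data/manifests/aishell4")
lang_char_dir = Path("data/lang_char")
lang_char_dir.mkdir(parents=True, exist_ok=True)

# Collect the "text" field of every supervision in the three
# training subsets into a single file, like `cat ... >> text_full`.
with open(lang_char_dir / "text_full", "w", encoding="utf-8") as out:
    for subset in ("train_S", "train_M", "train_L"):
        sups = load_manifest_lazy(
            manifest_dir / f"aishell4_supervisions_{subset}.jsonl.gz"
        )
        for sup in sups:
            out.write(sup.text + "\n")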

asr_datamodule.py

@@ -23,14 +23,8 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import torch
-from lhotse import (
-    CutSet,
-    Fbank,
-    FbankConfig,
-    load_manifest,
-    set_caching_enabled,
-)
-from lhotse.dataset import (
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
@@ -39,15 +33,15 @@ from lhotse.dataset import (
     SingleCutSampler,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.dataset.input_strategies import (  # noqa F401 for AudioSamples
+    AudioSamples,
+    OnTheFlyFeatures,
+)
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
-set_caching_enabled(False)
-torch.set_num_threads(1)
-
 
 class _SeedWorkers:
     def __init__(self, seed: int):
@@ -85,12 +79,14 @@ class Aishell4AsrDataModule:
             "effective batch sizes, sampling strategies, applied data "
             "augmentations, etc.",
         )
+
         group.add_argument(
             "--manifest-dir",
             type=Path,
             default=Path("data/fbank"),
             help="Path to directory with train/valid/test cuts.",
         )
+
         group.add_argument(
             "--max-duration",
             type=int,
@@ -98,6 +94,7 @@ class Aishell4AsrDataModule:
             help="Maximum pooled recordings duration (seconds) in a "
             "single batch. You can reduce it if it causes CUDA OOM.",
         )
+
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
@@ -105,6 +102,7 @@ class Aishell4AsrDataModule:
             help="When enabled, the batches will come from buckets of "
             "similar duration (saves padding frames).",
         )
+
         group.add_argument(
             "--num-buckets",
             type=int,
@@ -112,6 +110,7 @@ class Aishell4AsrDataModule:
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
+
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
@@ -119,6 +118,7 @@ class Aishell4AsrDataModule:
             help="When enabled, utterances (cuts) will be concatenated "
             "to minimize the amount of padding.",
         )
+
         group.add_argument(
             "--duration-factor",
             type=float,
@@ -126,6 +126,7 @@ class Aishell4AsrDataModule:
             help="Determines the maximum duration of a concatenated cut "
             "relative to the duration of the longest cut in a batch.",
         )
+
         group.add_argument(
             "--gap",
             type=float,
@@ -134,6 +135,7 @@ class Aishell4AsrDataModule:
             "concatenated cuts. This padding is filled with noise when "
             "noise augmentation is used.",
         )
+
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
@@ -142,6 +144,7 @@ class Aishell4AsrDataModule:
             "extraction. Will drop existing precomputed feature manifests "
             "if available.",
         )
+
         group.add_argument(
             "--shuffle",
             type=str2bool,
@@ -149,6 +152,14 @@ class Aishell4AsrDataModule:
             help="When enabled (=default), the examples will be "
             "shuffled for each epoch.",
         )
+
+        group.add_argument(
+            "--drop-last",
+            type=str2bool,
+            default=True,
+            help="Whether to drop last batch. Used by sampler.",
+        )
+
         group.add_argument(
             "--return-cuts",
             type=str2bool,
@@ -192,10 +203,10 @@ class Aishell4AsrDataModule:
         )
 
         group.add_argument(
-            "--lazy-load",
-            type=str2bool,
-            default=True,
-            help="lazily open CutSets to avoid OOM (for L|XL subset)",
+            "--input-strategy",
+            type=str,
+            default="PrecomputedFeatures",
+            help="AudioSamples or PrecomputedFeatures",
         )
 
         group.add_argument(
@@ -218,8 +229,8 @@ class Aishell4AsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.json.gz"
+        cuts_musan = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_musan.jsonl.gz"
         )
 
         transforms = []
@@ -277,6 +288,7 @@ class Aishell4AsrDataModule:
         logging.info("About to create train dataset")
         train = K2SpeechRecognitionDataset(
+            input_strategy=eval(self.args.input_strategy)(),
             cut_transforms=transforms,
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
@@ -310,7 +322,7 @@ class Aishell4AsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 buffer_size=30000,
-                drop_last=True,
+                drop_last=self.args.drop_last,
             )
         else:
             logging.info("Using SingleCutSampler.")
@@ -367,8 +379,6 @@ class Aishell4AsrDataModule:
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
@@ -393,14 +403,12 @@ class Aishell4AsrDataModule:
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
             if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
+            else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
         from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
@@ -419,26 +427,22 @@ class Aishell4AsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        if self.args.lazy_load:
-            logging.info("use lazy cuts")
-            cuts_train = CutSet.from_jsonl_lazy(
-                self.args.manifest_dir
-                / f"cuts_train_{self.args.training_subset}.json.gz"
-            )
-        else:
-            cuts_train = CutSet.from_file(
-                self.args.manifest_dir
-                / f"cuts_train_{self.args.training_subset}.json.gz"
-            )
-        return cuts_train
+        return load_manifest_lazy(
+            self.args.manifest_dir
+            / f"aishell4_cuts_train_{self.args.training_subset}.jsonl.gz"
+        )
 
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
         # Aishell4 doesn't have dev data, here use test to replace dev.
-        return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
+        )
 
     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
+        )
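One note on the new --input-strategy flag: since the string is resolved with eval(self.args.input_strategy)(), both AudioSamples and PrecomputedFeatures must be importable names in this module, which is exactly what the "# noqa F401" imports at the top of the file guarantee. An eval-free alternative for the same dispatch, shown here as a hypothetical helper that is not part of this commit:

from lhotse.dataset.input_strategies import AudioSamples, PrecomputedFeatures

# Explicit whitelist instead of eval()-ing a command-line string.
_INPUT_STRATEGIES = {
    "AudioSamples": AudioSamples,
    "PrecomputedFeatures": PrecomputedFeatures,
}


def make_input_strategy(name: str):
    """Return an instance of the requested input strategy."""
    try:
        return _INPUT_STRATEGIES[name]()
    except KeyError:
        raise ValueError(
            f"Unknown --input-strategy: {name!r}; "
            f"expected one of {sorted(_INPUT_STRATEGIES)}"
        ) from None

The explicit mapping rejects unknown values with a clear error instead of raising a NameError (or executing arbitrary code) the way eval() would.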