mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
disable speed perturbation by default (#1176)
* disable speed perturbation by default
* minor fixes
* minor updates
* updated bash scripts to incorporate with the `speed-perturb` arg
* minor fixes
1. changed the naming scheme from `speed-perturb` to `perturb-speed` to align with the librispeech recipe
>> 00256a7669/egs/librispeech/ASR/local/compute_fbank_librispeech.py (L65)
2. changed arg type for `perturb-speed` to str2bool
This commit is contained in:
parent
00256a7669
commit
74806b744b
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests/aidatatang_200zh")
|
src_dir = Path("data/manifests/aidatatang_200zh")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -109,7 +110,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -119,4 +125,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aidatatang_200zh(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -77,7 +77,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
log "Stage 4: Compute fbank for aidatatang_200zh"
|
log "Stage 4: Compute fbank for aidatatang_200zh"
|
||||||
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
|
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aidatatang_200zh.py
|
./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
|
||||||
touch data/fbank/.aidatatang_200zh.done
|
touch data/fbank/.aidatatang_200zh.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -109,7 +110,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -119,4 +125,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aidatatang_200zh(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell(num_mel_bins: int = 80):
|
def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -81,7 +81,8 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -104,7 +105,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -114,4 +120,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aishell(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -114,7 +114,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
log "Stage 3: Compute fbank for aishell"
|
log "Stage 3: Compute fbank for aishell"
|
||||||
if [ ! -f data/fbank/.aishell.done ]; then
|
if [ ! -f data/fbank/.aishell.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aishell.py
|
./local/compute_fbank_aishell.py --perturb-speed True
|
||||||
touch data/fbank/.aishell.done
|
touch data/fbank/.aishell.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -53,7 +53,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
log "Stage 2: Process aidatatang_200zh"
|
log "Stage 2: Process aidatatang_200zh"
|
||||||
if [ ! -f data/fbank/.aidatatang_200zh_fbank.done ]; then
|
if [ ! -f data/fbank/.aidatatang_200zh_fbank.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aidatatang_200zh.py
|
./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
|
||||||
touch data/fbank/.aidatatang_200zh_fbank.done
|
touch data/fbank/.aidatatang_200zh_fbank.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell2(num_mel_bins: int = 80):
|
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -81,7 +81,8 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -104,6 +105,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -114,4 +121,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell2(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aishell2(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -101,7 +101,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
log "Stage 3: Compute fbank for aishell2"
|
log "Stage 3: Compute fbank for aishell2"
|
||||||
if [ ! -f data/fbank/.aishell2.done ]; then
|
if [ ! -f data/fbank/.aishell2.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aishell2.py
|
./local/compute_fbank_aishell2.py --perturb-speed True
|
||||||
touch data/fbank/.aishell2.done
|
touch data/fbank/.aishell2.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
|
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell4(num_mel_bins: int = 80):
|
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests/aishell4")
|
src_dir = Path("data/manifests/aishell4")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -83,10 +83,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
|
|
||||||
cut_set = cut_set.compute_and_store_features(
|
cut_set = cut_set.compute_and_store_features(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||||
@ -113,6 +115,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -123,4 +131,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell4(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aishell4(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -107,7 +107,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
log "Stage 5: Compute fbank for aishell4"
|
log "Stage 5: Compute fbank for aishell4"
|
||||||
if [ ! -f data/fbank/.aishell4.done ]; then
|
if [ ! -f data/fbank/.aishell4.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aishell4.py
|
./local/compute_fbank_aishell4.py --perturb-speed True
|
||||||
touch data/fbank/.aishell4.done
|
touch data/fbank/.aishell4.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_alimeeting(num_mel_bins: int = 80):
|
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests/alimeeting")
|
src_dir = Path("data/manifests/alimeeting")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -82,7 +82,8 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -114,6 +115,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -124,4 +131,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_alimeeting(num_mel_bins=args.num_mel_bins)
|
compute_fbank_alimeeting(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -97,7 +97,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
log "Stage 5: Compute fbank for alimeeting"
|
log "Stage 5: Compute fbank for alimeeting"
|
||||||
if [ ! -f data/fbank/.alimeeting.done ]; then
|
if [ ! -f data/fbank/.alimeeting.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_alimeeting.py
|
./local/compute_fbank_alimeeting.py --perturb-speed True
|
||||||
touch data/fbank/.alimeeting.done
|
touch data/fbank/.alimeeting.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -25,6 +25,7 @@ It looks for manifests in the directory data/manifests.
|
|||||||
|
|
||||||
The generated fbank features are saved in data/fbank.
|
The generated fbank features are saved in data/fbank.
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -39,6 +40,8 @@ from lhotse.features.kaldifeat import (
|
|||||||
)
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
# Do this outside of main() in case it needs to take effect
|
# Do this outside of main() in case it needs to take effect
|
||||||
@ -48,7 +51,7 @@ torch.set_num_interop_threads(1)
|
|||||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_ami():
|
def compute_fbank_ami(perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
|
|
||||||
@ -84,8 +87,12 @@ def compute_fbank_ami():
|
|||||||
suffix="jsonl.gz",
|
suffix="jsonl.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
|
def _extract_feats(
|
||||||
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
|
cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool
|
||||||
|
) -> None:
|
||||||
|
if speed_perturb:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
|
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
|
||||||
_ = cuts.compute_and_store_features_batch(
|
_ = cuts.compute_and_store_features_batch(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=storage_path,
|
storage_path=storage_path,
|
||||||
@ -109,6 +116,7 @@ def compute_fbank_ami():
|
|||||||
cuts_ihm,
|
cuts_ihm,
|
||||||
output_dir / "feats_train_ihm",
|
output_dir / "feats_train_ihm",
|
||||||
src_dir / "cuts_train_ihm.jsonl.gz",
|
src_dir / "cuts_train_ihm.jsonl.gz",
|
||||||
|
perturb_speed,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Processing train split IHM + reverberated IHM")
|
logging.info("Processing train split IHM + reverberated IHM")
|
||||||
@ -117,6 +125,7 @@ def compute_fbank_ami():
|
|||||||
cuts_ihm_rvb,
|
cuts_ihm_rvb,
|
||||||
output_dir / "feats_train_ihm_rvb",
|
output_dir / "feats_train_ihm_rvb",
|
||||||
src_dir / "cuts_train_ihm_rvb.jsonl.gz",
|
src_dir / "cuts_train_ihm_rvb.jsonl.gz",
|
||||||
|
perturb_speed,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Processing train split SDM")
|
logging.info("Processing train split SDM")
|
||||||
@ -129,6 +138,7 @@ def compute_fbank_ami():
|
|||||||
cuts_sdm,
|
cuts_sdm,
|
||||||
output_dir / "feats_train_sdm",
|
output_dir / "feats_train_sdm",
|
||||||
src_dir / "cuts_train_sdm.jsonl.gz",
|
src_dir / "cuts_train_sdm.jsonl.gz",
|
||||||
|
perturb_speed,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Processing train split GSS")
|
logging.info("Processing train split GSS")
|
||||||
@ -141,6 +151,7 @@ def compute_fbank_ami():
|
|||||||
cuts_gss,
|
cuts_gss,
|
||||||
output_dir / "feats_train_gss",
|
output_dir / "feats_train_gss",
|
||||||
src_dir / "cuts_train_gss.jsonl.gz",
|
src_dir / "cuts_train_gss.jsonl.gz",
|
||||||
|
perturb_speed,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
|
logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
|
||||||
@ -186,8 +197,21 @@ def compute_fbank_ami():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
compute_fbank_ami()
|
args = get_args()
|
||||||
|
|
||||||
|
compute_fbank_ami(perturb_speed=args.perturb_speed)
|
||||||
|
@ -85,7 +85,7 @@ fi
|
|||||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
log "Stage 5: Compute fbank for alimeeting"
|
log "Stage 5: Compute fbank for alimeeting"
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
python local/compute_fbank_alimeeting.py
|
python local/compute_fbank_alimeeting.py --perturb-speed True
|
||||||
log "Combine features from train splits"
|
log "Combine features from train splits"
|
||||||
lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
|
lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
|
||||||
gzip -c > data/manifests/cuts_train_all.jsonl.gz
|
gzip -c > data/manifests/cuts_train_all.jsonl.gz
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -24,6 +25,7 @@ from lhotse import CutSet, SupervisionSegment
|
|||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall import setup_logger
|
from icefall import setup_logger
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
# Similar text filtering and normalization procedure as in:
|
# Similar text filtering and normalization procedure as in:
|
||||||
# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
|
# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
|
||||||
@ -45,7 +47,7 @@ def has_no_oov(
|
|||||||
return oov_pattern.search(sup.text) is None
|
return oov_pattern.search(sup.text) is None
|
||||||
|
|
||||||
|
|
||||||
def preprocess_wenet_speech():
|
def preprocess_wenet_speech(perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
output_dir.mkdir(exist_ok=True)
|
output_dir.mkdir(exist_ok=True)
|
||||||
@ -110,7 +112,7 @@ def preprocess_wenet_speech():
|
|||||||
)
|
)
|
||||||
# Run data augmentation that needs to be done in the
|
# Run data augmentation that needs to be done in the
|
||||||
# time domain.
|
# time domain.
|
||||||
if partition not in ["DEV", "TEST_NET", "TEST_MEETING"]:
|
if partition not in ["DEV", "TEST_NET", "TEST_MEETING"] and perturb_speed:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
|
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
|
||||||
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
|
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
|
||||||
@ -120,10 +122,22 @@ def preprocess_wenet_speech():
|
|||||||
cut_set.to_file(raw_cuts_path)
|
cut_set.to_file(raw_cuts_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
setup_logger(log_filename="./log-preprocess-wenetspeech")
|
setup_logger(log_filename="./log-preprocess-wenetspeech")
|
||||||
|
|
||||||
preprocess_wenet_speech()
|
args = get_args()
|
||||||
|
preprocess_wenet_speech(perturb_speed=args.perturb_speed)
|
||||||
logging.info("Done")
|
logging.info("Done")
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ fi
|
|||||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
log "Stage 3: Preprocess WenetSpeech manifest"
|
log "Stage 3: Preprocess WenetSpeech manifest"
|
||||||
if [ ! -f data/fbank/.preprocess_complete ]; then
|
if [ ! -f data/fbank/.preprocess_complete ]; then
|
||||||
python3 ./local/preprocess_wenetspeech.py
|
python3 ./local/preprocess_wenetspeech.py --perturb-speed True
|
||||||
touch data/fbank/.preprocess_complete
|
touch data/fbank/.preprocess_complete
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
Loading…
x
Reference in New Issue
Block a user