clean codes

Yuekai Zhang 2024-01-15 20:40:44 +08:00
parent eea46458c5
commit 557b35cefc
8 changed files with 266 additions and 575 deletions

egs/aishell/ASR/local/compute_fbank_aishell.py

@@ -29,7 +29,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor, str2bool
@@ -42,7 +42,7 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -68,8 +68,10 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
         list(manifests.keys()),
         dataset_parts,
     )
-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -111,6 +113,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use WhisperFbank instead of Fbank. Default: False.",
+    )
 
     return parser.parse_args()
 
@@ -121,5 +129,5 @@ if __name__ == "__main__":
     args = get_args()
 
     compute_fbank_aishell(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
    )
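As a quick reference, this is the selection logic the new flag enables, in one place (a minimal sketch, not part of the commit; make_extractor is a hypothetical name, and the 128-bin call mirrors the whisper large-v3 note in prepare.sh below):

from lhotse import Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig

def make_extractor(num_mel_bins: int = 80, whisper_fbank: bool = False):
    # WhisperFbank computes Whisper-style log-mel features (here on GPU);
    # Fbank is the Kaldi-style filterbank used elsewhere in the recipe.
    if whisper_fbank:
        return WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
    return Fbank(FbankConfig(num_mel_bins=num_mel_bins))

# e.g. 128-bin features for whisper large-v3 fine-tuning:
extractor = make_extractor(num_mel_bins=128, whisper_fbank=True)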

egs/aishell/ASR/local/compute_whisper_fbank_aishell.py (deleted)

@@ -1,125 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the aishell dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"train",
"dev",
"test",
)
prefix = "aishell"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aishell(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
)

egs/aishell/ASR/local/compute_whisper_fbank_musan.py (deleted)

@@ -1,109 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
dataset_parts = (
"music",
"speech",
"noise",
)
prefix = "musan"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return
logging.info("Extracting features for Musan")
extractor = WhisperFbank(WhisperFbankConfig(device='cuda'))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(is_cut_long)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/musan_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
)
musan_cuts.to_file(musan_cuts_path)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()

egs/aishell/ASR/prepare.sh

@@ -1,5 +1,5 @@
 #!/usr/bin/env bash
 
-export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall
+
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 
@@ -83,9 +83,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   #
   # ln -sfv /path/to/musan $dl_dir/musan
   #
-  # if [ ! -d $dl_dir/musan ]; then
-  #   lhotse download musan $dl_dir
-  # fi
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@@ -99,17 +99,17 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   fi
 fi
 
-# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-#   log "Stage 2: Prepare musan manifest"
-#   # We assume that you have downloaded the musan corpus
-#   # to data/musan
-#   if [ ! -f data/manifests/.musan_manifests.done ]; then
-#     log "It may take 6 minutes"
-#     mkdir -p data/manifests
-#     lhotse prepare musan $dl_dir/musan data/manifests
-#     touch data/manifests/.musan_manifests.done
-#   fi
-# fi
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  if [ ! -f data/manifests/.musan_manifests.done ]; then
+    log "It may take 6 minutes"
+    mkdir -p data/manifests
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan_manifests.done
+  fi
+fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Compute fbank for aishell"
@@ -120,56 +120,29 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi
 
-# if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
-#   log "Stage 30: Compute whisper fbank for aishell"
-#   if [ ! -f data/fbank/.aishell.done ]; then
-#     mkdir -p data/fbank
-#     ./local/compute_whisper_fbank_aishell.py --perturb-speed True
-#     touch data/fbank/.aishell.done
-#   fi
-# fi
-
-if [ $stage -le 300 ] && [ $stop_stage -ge 300 ]; then
-  log "Stage 30: Compute whisper fbank for aishell"
-  if [ ! -f data/fbank/.aishell.done ]; then
-    mkdir -p data/fbank
-    ./local/compute_whisper_fbank_aishell.py --perturb-speed True --num-mel-bins 128
-    touch data/fbank/.aishell.done
-  fi
-fi
-
-# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-#   log "Stage 4: Compute fbank for musan"
-#   if [ ! -f data/fbank/.msuan.done ]; then
-#     mkdir -p data/fbank
-#     ./local/compute_fbank_musan.py
-#     touch data/fbank/.msuan.done
-#   fi
-# fi
-
-# if [ $stage -le 40 ] && [ $stop_stage -ge 40 ]; then
-#   log "Stage 4: Compute fbank for musan"
-#   if [ ! -f data/fbank/.msuan.done ]; then
-#     mkdir -p data/fbank
-#     ./local/compute_whisper_fbank_musan.py
-#     touch data/fbank/.msuan.done
-#   fi
-# fi
-
-# lang_phone_dir=data/lang_phone
-# if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-#   log "Stage 5: Prepare phone based lang"
-#   mkdir -p $lang_phone_dir
-
-#   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-#     cat - $dl_dir/aishell/resource_aishell/lexicon.txt |
-#     sort | uniq > $lang_phone_dir/lexicon.txt
-
-#   ./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
-
-#   if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
-#     ./local/prepare_lang.py --lang-dir $lang_phone_dir
-#   fi
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  if [ ! -f data/fbank/.msuan.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.msuan.done
+  fi
+fi
+
+lang_phone_dir=data/lang_phone
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  mkdir -p $lang_phone_dir
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/aishell/resource_aishell/lexicon.txt |
+    sort | uniq > $lang_phone_dir/lexicon.txt
+
+  ./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
+
+  if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_phone_dir
+  fi
 
   # Train a bigram P for MMI training
@@ -182,93 +155,93 @@ fi
     cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
   fi
 
-#   if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
-#     ./local/convert_transcript_words_to_tokens.py \
-#       --lexicon $lang_phone_dir/uniq_lexicon.txt \
-#       --transcript $lang_phone_dir/transcript_words.txt \
-#       --oov "<UNK>" \
-#       > $lang_phone_dir/transcript_tokens.txt
-#   fi
+  if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
+    ./local/convert_transcript_words_to_tokens.py \
+      --lexicon $lang_phone_dir/uniq_lexicon.txt \
+      --transcript $lang_phone_dir/transcript_words.txt \
+      --oov "<UNK>" \
+      > $lang_phone_dir/transcript_tokens.txt
+  fi
 
-#   if [ ! -f $lang_phone_dir/P.arpa ]; then
-#     ./shared/make_kn_lm.py \
-#       -ngram-order 2 \
-#       -text $lang_phone_dir/transcript_tokens.txt \
-#       -lm $lang_phone_dir/P.arpa
-#   fi
+  if [ ! -f $lang_phone_dir/P.arpa ]; then
+    ./shared/make_kn_lm.py \
+      -ngram-order 2 \
+      -text $lang_phone_dir/transcript_tokens.txt \
+      -lm $lang_phone_dir/P.arpa
+  fi
 
-#   if [ ! -f $lang_phone_dir/P.fst.txt ]; then
-#     python3 -m kaldilm \
-#       --read-symbol-table="$lang_phone_dir/tokens.txt" \
-#       --disambig-symbol='#0' \
-#       --max-order=2 \
-#       $lang_phone_dir/P.arpa > $lang_phone_dir/P.fst.txt
-#   fi
-# fi
+  if [ ! -f $lang_phone_dir/P.fst.txt ]; then
+    python3 -m kaldilm \
+      --read-symbol-table="$lang_phone_dir/tokens.txt" \
+      --disambig-symbol='#0' \
+      --max-order=2 \
+      $lang_phone_dir/P.arpa > $lang_phone_dir/P.fst.txt
+  fi
+fi
 
-# lang_char_dir=data/lang_char
-# if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-#   log "Stage 6: Prepare char based lang"
-#   mkdir -p $lang_char_dir
-#   # We reuse words.txt from phone based lexicon
-#   # so that the two can share G.pt later.
-#   # The transcripts in training set, generated in stage 5
-#   cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
+lang_char_dir=data/lang_char
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare char based lang"
+  mkdir -p $lang_char_dir
+  # We reuse words.txt from phone based lexicon
+  # so that the two can share G.pt later.
+  # The transcripts in training set, generated in stage 5
+  cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
 
-#   cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
-#     cut -d " " -f 2- > $lang_char_dir/text
+  cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
+    cut -d " " -f 2- > $lang_char_dir/text
 
-#   (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
-#     > $lang_char_dir/words.txt
+  (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
+    > $lang_char_dir/words.txt
 
-#   cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
-#     | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
+  cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
+    | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
 
-#   num_lines=$(< $lang_char_dir/words.txt wc -l)
-#   (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
-#     >> $lang_char_dir/words.txt
+  num_lines=$(< $lang_char_dir/words.txt wc -l)
+  (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
+    >> $lang_char_dir/words.txt
 
-#   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
-#     ./local/prepare_char.py --lang-dir $lang_char_dir
-#   fi
-# fi
+  if [ ! -f $lang_char_dir/L_disambig.pt ]; then
+    ./local/prepare_char.py --lang-dir $lang_char_dir
+  fi
+fi
 
-# if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-#   log "Stage 7: Prepare Byte BPE based lang"
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare Byte BPE based lang"
 
-#   for vocab_size in ${vocab_sizes[@]}; do
-#     lang_dir=data/lang_bbpe_${vocab_size}
-#     mkdir -p $lang_dir
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bbpe_${vocab_size}
+    mkdir -p $lang_dir
 
-#     cp $lang_char_dir/words.txt $lang_dir
-#     cp $lang_char_dir/text $lang_dir
+    cp $lang_char_dir/words.txt $lang_dir
+    cp $lang_char_dir/text $lang_dir
 
-#     if [ ! -f $lang_dir/bbpe.model ]; then
-#       ./local/train_bbpe_model.py \
-#         --lang-dir $lang_dir \
-#         --vocab-size $vocab_size \
-#         --transcript $lang_dir/text
-#     fi
+    if [ ! -f $lang_dir/bbpe.model ]; then
+      ./local/train_bbpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/text
+    fi
 
-#     if [ ! -f $lang_dir/L_disambig.pt ]; then
-#       ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
-#     fi
-#   done
-# fi
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
+    fi
+  done
+fi
 
-# if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
-#   log "Stage 8: Prepare G"
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
 
-#   mkdir -p data/lm
+  mkdir -p data/lm
 
-#   # Train LM on transcripts
-#   if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
-#     python3 ./shared/make_kn_lm.py \
-#       -ngram-order 3 \
-#       -text $lang_char_dir/transcript_words.txt \
-#       -lm data/lm/3-gram.unpruned.arpa
-#   fi
+  # Train LM on transcripts
+  if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
+    python3 ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text $lang_char_dir/transcript_words.txt \
+      -lm data/lm/3-gram.unpruned.arpa
+  fi
 
   # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm
@@ -294,112 +267,124 @@ fi
   fi
 fi
 
-# if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
-#   log "Stage 9: Compile LG & HLG"
-#   ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
-#   ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
-#   for vocab_size in ${vocab_sizes[@]}; do
-#     lang_dir=data/lang_bbpe_${vocab_size}
-#     ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
-#   done
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile LG & HLG"
+  ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
+  ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bbpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
+  done
 
-#   ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
-#   ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
-#   for vocab_size in ${vocab_sizes[@]}; do
-#     lang_dir=data/lang_bbpe_${vocab_size}
-#     ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
-#   done
-# fi
+  ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
+  ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bbpe_${vocab_size}
+    ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
+  done
+fi
 
-# if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-#   log "Stage 10: Generate LM training data"
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Generate LM training data"
 
-#   log "Processing char based data"
-#   out_dir=data/lm_training_char
-#   mkdir -p $out_dir $dl_dir/lm
+  log "Processing char based data"
+  out_dir=data/lm_training_char
+  mkdir -p $out_dir $dl_dir/lm
 
-#   if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
-#     cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
-#   fi
+  if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
+    cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
+  fi
 
-#   # training words
-#   ./local/prepare_char_lm_training_data.py \
-#     --lang-char data/lang_char \
-#     --lm-data $dl_dir/lm/aishell-train-word.txt \
-#     --lm-archive $out_dir/lm_data.pt
+  # training words
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-train-word.txt \
+    --lm-archive $out_dir/lm_data.pt
 
-#   # valid words
-#   if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
-#     aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
-#     aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
-#     find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
-#     awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
-#       cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
-#   fi
+  # valid words
+  if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
+    find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
+  fi
 
-#   ./local/prepare_char_lm_training_data.py \
-#     --lang-char data/lang_char \
-#     --lm-data $dl_dir/lm/aishell-valid-word.txt \
-#     --lm-archive $out_dir/lm_data_valid.pt
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-valid-word.txt \
+    --lm-archive $out_dir/lm_data_valid.pt
 
-#   # test words
-#   if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
-#     aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
-#     aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
-#     find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
-#     awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
-#       cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
-#   fi
+  # test words
+  if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
+    find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
+  fi
 
-#   ./local/prepare_char_lm_training_data.py \
-#     --lang-char data/lang_char \
-#     --lm-data $dl_dir/lm/aishell-test-word.txt \
-#     --lm-archive $out_dir/lm_data_test.pt
-# fi
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-test-word.txt \
+    --lm-archive $out_dir/lm_data_test.pt
+fi
 
-# if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-#   log "Stage 11: Sort LM training data"
-#   # Sort LM training data by sentence length in descending order
-#   # for ease of training.
-#   #
-#   # Sentence length equals to the number of tokens
-#   # in a sentence.
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Sort LM training data"
+  # Sort LM training data by sentence length in descending order
+  # for ease of training.
+  #
+  # Sentence length equals to the number of tokens
+  # in a sentence.
 
-#   out_dir=data/lm_training_char
-#   mkdir -p $out_dir
-#   ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
+  out_dir=data/lm_training_char
+  mkdir -p $out_dir
+  ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
 
-#   ./local/sort_lm_training_data.py \
-#     --in-lm-data $out_dir/lm_data.pt \
-#     --out-lm-data $out_dir/sorted_lm_data.pt \
-#     --out-statistics $out_dir/statistics.txt
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data.pt \
+    --out-lm-data $out_dir/sorted_lm_data.pt \
+    --out-statistics $out_dir/statistics.txt
 
-#   ./local/sort_lm_training_data.py \
-#     --in-lm-data $out_dir/lm_data_valid.pt \
-#     --out-lm-data $out_dir/sorted_lm_data-valid.pt \
-#     --out-statistics $out_dir/statistics-valid.txt
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_valid.pt \
+    --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+    --out-statistics $out_dir/statistics-valid.txt
 
-#   ./local/sort_lm_training_data.py \
-#     --in-lm-data $out_dir/lm_data_test.pt \
-#     --out-lm-data $out_dir/sorted_lm_data-test.pt \
-#     --out-statistics $out_dir/statistics-test.txt
-# fi
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_test.pt \
+    --out-lm-data $out_dir/sorted_lm_data-test.pt \
+    --out-statistics $out_dir/statistics-test.txt
+fi
-# if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
-#   log "Stage 11: Train RNN LM model"
-#   python ../../../icefall/rnn_lm/train.py \
-#     --start-epoch 0 \
-#     --world-size 1 \
-#     --num-epochs 20 \
-#     --use-fp16 0 \
-#     --embedding-dim 512 \
-#     --hidden-dim 512 \
-#     --num-layers 2 \
-#     --batch-size 400 \
-#     --exp-dir rnnlm_char/exp \
-#     --lm-data $out_dir/sorted_lm_data.pt \
-#     --lm-data-valid $out_dir/sorted_lm_data-valid.pt \
-#     --vocab-size 4336 \
-#     --master-port 12345
-# fi
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "Stage 12: Train RNN LM model"
+  python ../../../icefall/rnn_lm/train.py \
+    --start-epoch 0 \
+    --world-size 1 \
+    --num-epochs 20 \
+    --use-fp16 0 \
+    --embedding-dim 512 \
+    --hidden-dim 512 \
+    --num-layers 2 \
+    --batch-size 400 \
+    --exp-dir rnnlm_char/exp \
+    --lm-data $out_dir/sorted_lm_data.pt \
+    --lm-data-valid $out_dir/sorted_lm_data-valid.pt \
+    --vocab-size 4336 \
+    --master-port 12345
+fi
+
+# whisper large-v3 uses 128 mel bins; the other whisper models use 80
+whisper_mel_bins=80
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+  log "Stage 30: Compute ${whisper_mel_bins}-dim fbank for whisper model fine-tuning"
+  if [ ! -f data/fbank/.aishell.whisper.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.aishell.whisper.done
+  fi
+fi

egs/aishell/ASR/whisper/decode.py (mode changed: normal file → executable file)

@@ -115,10 +115,8 @@ def remove_punctuation(text: str or List[str]):
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'不支持该类型{type(text)}')
+        raise Exception(f'Unsupported type: {type(text)}')
 
-# Convert Traditional Chinese to Simplified Chinese
 def to_simple(text: str or List[str]):
     if isinstance(text, str):
         text = convert(text, 'zh-cn')
@@ -130,7 +128,7 @@ def to_simple(text: str or List[str]):
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'不支持该类型{type(text)}')
+        raise Exception(f'Unsupported type: {type(text)}')
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -192,27 +190,11 @@ def get_parser():
 
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            # parameters for conformer
-            "subsampling_factor": 4,
-            "feature_dim": 80,
-            "nhead": 4,
-            "attention_dim": 512,
-            "num_encoder_layers": 12,
-            "num_decoder_layers": 6,
-            "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "search_beam": 20,
-            "output_beam": 7,
-            "min_active_states": 30,
-            "max_active_states": 10000,
-            "use_double_scores": True,
             "env_info": get_env_info(),
         }
     )
 
     return params
 
 
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -264,12 +246,6 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device, dtype=dtype).transpose(1, 2)
-    # pad feature to T = 3000
-    #T = 3000
-    #if feature.shape[2] < T:
-    #    feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
-    print(feature.shape,23333)
-    # at entry, feature is (N, T, C)
 
     supervisions = batch["supervisions"]
     feature_len = supervisions["num_frames"]
@@ -281,7 +257,6 @@ def decode_one_batch(
     hyps = to_simple(hyps)
     hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
-    print(hyps, 233333333)
 
     key = "beam-search"
@@ -467,17 +442,14 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
     aishell = AishellAsrDataModule(args)
-    test_cuts = aishell.test_cuts()
-    test_dl = aishell.test_dataloaders(test_cuts)
     valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
-    #test_sets = ["test"]
-    #test_dls = [test_dl]
-    test_sets = ["valid"]
-    test_dls = [valid_dl]
+    test_dl = aishell.test_dataloaders(aishell.test_cuts())
+    test_sets = ["valid", "test"]
+    test_dls = [valid_dl, test_dl]
 
     for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,

egs/aishell/ASR/whisper/model.py (mode changed: normal file → executable file)

@@ -168,10 +168,8 @@ class AudioEncoder(nn.Module):
         x = F.gelu(self.conv2(x))
         x = x.permute(0, 2, 1)
 
-        # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+        # let whisper process audio of any length
         x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
-        #x = (x + self.positional_embedding).to(x.dtype)
 
         for block in self.blocks:
             x = block(x)
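The slice in that hunk is what lifts Whisper's fixed 30-second input assumption: the encoder's positional-embedding table covers the full 1500 post-convolution frames, and indexing it by the actual frame count lets shorter inputs through. A toy illustration (the shapes are assumptions for a whisper-tiny-sized encoder, not taken from this commit):

import torch

n_ctx, n_state = 1500, 384  # illustrative encoder sizes
positional_embedding = torch.zeros(n_ctx, n_state)  # stand-in for the sinusoidal table

x = torch.randn(2, 600, n_state)  # ~12 s of audio -> 600 frames, fewer than n_ctx
# old: assert x.shape[1:] == positional_embedding.shape  -> fails for short audio
# new: add only the first x.shape[1] rows of the table
x = (x + positional_embedding[: x.shape[1], :]).to(x.dtype)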
@@ -224,7 +222,6 @@ class TextDecoder(nn.Module):
         return logits
 
-
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
         super().__init__()
@@ -315,7 +312,6 @@ class Whisper(nn.Module):
         self.decoder.apply(install_hooks)
         return cache, hooks
 
-    #detect_language = detect_language_function
     transcribe = transcribe_function
     decode = decode_function
@@ -432,9 +428,4 @@ def load_model(
 
     model = Whisper(dims)
     model.load_state_dict(checkpoint["model_state_dict"])
 
-    # if alignment_heads is not None:
-    #     model.set_alignment_heads(alignment_heads)
-
     return model.to(device)

egs/aishell/ASR/whisper/train.py (mode changed: normal file → executable file)

@@ -51,26 +51,24 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from typing import List
-#from aishell import AIShell
-#from asr_datamodule import AsrDataModule
 from asr_datamodule import AishellAsrDataModule
-#from decoder import Decoder
-#from joiner import Joiner
 from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-#from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.functional import pad as pad_tensor
 from torch.utils.tensorboard import SummaryWriter
-#from zipformer import Zipformer
 
 from icefall import diagnostics
-#from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -109,13 +107,6 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
-    parser.add_argument(
-        "--master-port",
-        type=int,
-        default=12354,
-        help="Master port to use for DDP training.",
-    )
-
     parser.add_argument(
         "--tensorboard",
         type=str2bool,
@@ -209,19 +200,6 @@ def get_parser():
         help="Add hooks to check for infinite module outputs and gradients.",
     )
 
-    parser.add_argument(
-        "--save-every-n",
-        type=int,
-        default=4000,
-        help="""Save checkpoint after processing this number of batches"
-        periodically. We save checkpoint to exp-dir/ whenever
-        params.batch_idx_train % save_every_n == 0. The checkpoint filename
-        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
-        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
-        end of each epoch where `xxx` is the epoch number counting from 0.
-        """,
-    )
-
     parser.add_argument(
         "--keep-last-k",
         type=int,
@@ -314,11 +292,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 999999999999999999,  # For the 100h subset, use 800
-            # parameters for zipformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,  # not passed in, this is fixed.
-            "warm_step": 100,
+            "valid_interval": 9999999,
             "env_info": get_env_info(),
         }
     )
@@ -491,26 +465,25 @@ def compute_loss(
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
 
     feature = batch["inputs"]
-    # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)
     feature = feature.transpose(1, 2)  # (N, C, T)
-    # pad feature from B,80,T to B,80,3000
-    #feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
     batch_idx_train = params.batch_idx_train
-    warm_step = params.warm_step
 
     texts = batch["supervisions"]["text"]
     # remove spaces in texts
     texts = [text.replace(" ", "") for text in texts]
-    #print(texts)
 
     text_tokens_list = [list(tokenizer.sot_sequence_including_notimestamps) + tokenizer.encode(text) + [tokenizer.eot] for text in texts]
     # convert it to torch tensor
     text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]
 
+    # 50256 is the index of <pad> for all whisper models
     prev_outputs_tokens = _batch_tensors(
         [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
     )
@@ -522,9 +495,10 @@ def compute_loss(
     )
 
     decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
+    # ignore the first 3 tokens, which are always <sot>, <lang_id>, <transcribe>
     ignore_prefix_size = 3
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
         text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
         loss = decoder_criterion(text_logits, target_tokens.to(device))
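For context on the loss in that hunk: training is plain teacher forcing. Each label sequence is the prompt (<sot>, <lang_id>, <transcribe>) plus the text tokens plus <eot>; the decoder input drops the last token, the target drops the first, and both are right-padded with 50256. A hedged sketch of that step (the _batch_tensors body is a plausible stand-in for the helper in this file, and the token ids are toy values):

import torch

def _batch_tensors(tensors, pad_value):
    # right-pad 1-D token tensors to a common length
    max_len = max(t.size(0) for t in tensors)
    out = torch.full((len(tensors), max_len), pad_value, dtype=torch.long)
    for i, t in enumerate(tensors):
        out[i, : t.size(0)] = t
    return out

# toy ids: 3-token prompt + text tokens + <eot>, per sequence
text_tokens_list = [torch.LongTensor([1, 2, 3, 7, 8, 4]),
                    torch.LongTensor([1, 2, 3, 9, 4])]
prev_outputs_tokens = _batch_tensors([t[:-1] for t in text_tokens_list], pad_value=50256)  # decoder inputs
target_tokens = _batch_tensors([t[1:] for t in text_tokens_list], pad_value=50256)  # shifted targets
# the criterion then skips pad positions (ignore_index=50256) and, via
# ignore_prefix_size=3, the prompt positions at the head of each target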
@@ -697,27 +671,6 @@ def train_one_epoch(
                 model_avg=model_avg,
             )
 
-        # if (
-        #     params.batch_idx_train > 0
-        #     and params.batch_idx_train % params.save_every_n == 0
-        # ):
-        #     save_checkpoint_with_global_batch_idx(
-        #         out_dir=params.exp_dir,
-        #         global_batch_idx=params.batch_idx_train,
-        #         model=model,
-        #         model_avg=model_avg,
-        #         params=params,
-        #         optimizer=optimizer,
-        #         scheduler=scheduler,
-        #         sampler=train_dl.sampler,
-        #         scaler=scaler,
-        #         rank=rank,
-        #     )
-        #     remove_checkpoints(
-        #         out_dir=params.exp_dir,
-        #         topk=params.keep_last_k,
-        #         rank=rank,
-        #     )
 
         if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
@@ -791,8 +744,7 @@ def run(rank, world_size, args):
     logging.info(params)
 
     logging.info("About to create model")
-    # TODO download model only on rank 0
-    # TODO may change compute validation loss using multiple cards
+
     model = load_model(params.model_name)
     del model.alignment_heads
 
     num_param = sum([p.numel() for p in model.parameters()])
@@ -802,7 +754,6 @@ def run(rank, world_size, args):
         model.is_multilingual, num_languages=model.num_languages, language="zh", task="transcribe"
     )
 
-    assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
     if rank == 0:
         # model_avg is only used with rank 0
@@ -863,9 +814,8 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    train_dl = aishell.train_dataloaders(aishell.train_cuts(), rank=rank, world_size=world_size)
-    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts(), rank=rank, world_size=world_size)
+    train_dl = aishell.train_dataloaders(aishell.train_cuts())
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:

egs/aishell/ASR/local/compute_fbank_musan.py

@@ -28,7 +28,9 @@ import os
 
 from pathlib import Path
 
+import argparse
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
+from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine
 from lhotse.recipes.utils import read_manifests_if_cached
 
-from icefall.utils import get_executor
+from icefall.utils import get_executor, str2bool
@@ -45,11 +47,10 @@ def is_cut_long(c: MonoCut) -> bool:
     return c.duration > 5
 
 
-def compute_fbank_musan():
+def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
-    num_mel_bins = 80
 
     dataset_parts = (
         "music",
@@ -81,7 +82,10 @@ def compute_fbank_musan():
 
     logging.info("Extracting features for Musan")
 
-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         # create chunks of Musan with duration 5 - 10 seconds
@@ -101,9 +105,29 @@ def compute_fbank_musan():
     )
     musan_cuts.to_file(musan_cuts_path)
 
 
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--num-mel-bins",
+        type=int,
+        default=80,
+        help="""The number of mel bins for Fbank""",
+    )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use WhisperFbank instead of Fbank. Default: False.",
+    )
+    return parser.parse_args()
+
+
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
-    compute_fbank_musan()
+    args = get_args()
+    compute_fbank_musan(
+        num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank
+    )