clean codes

This commit is contained in:
Yuekai Zhang 2024-01-15 20:40:44 +08:00
parent eea46458c5
commit 557b35cefc
8 changed files with 266 additions and 575 deletions

View File

@ -29,7 +29,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -42,7 +42,7 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -68,8 +68,10 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
list(manifests.keys()),
dataset_parts,
)
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@ -111,6 +113,12 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args()
@ -121,5 +129,5 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_aishell(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
)

View File

@ -1,125 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the aishell dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"train",
"dev",
"test",
)
prefix = "aishell"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_aishell(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
)

View File

@ -1,109 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
dataset_parts = (
"music",
"speech",
"noise",
)
prefix = "musan"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return
logging.info("Extracting features for Musan")
extractor = WhisperFbank(WhisperFbankConfig(device='cuda'))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(is_cut_long)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/musan_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
)
musan_cuts.to_file(musan_cuts_path)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()

View File

@ -1,5 +1,5 @@
#!/usr/bin/env bash
export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
@ -83,9 +83,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
#
# ln -sfv /path/to/musan $dl_dir/musan
#
# if [ ! -d $dl_dir/musan ]; then
# lhotse download musan $dl_dir
# fi
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@ -99,17 +99,17 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi
fi
# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# log "Stage 2: Prepare musan manifest"
# # We assume that you have downloaded the musan corpus
# # to data/musan
# if [ ! -f data/manifests/.musan_manifests.done ]; then
# log "It may take 6 minutes"
# mkdir -p data/manifests
# lhotse prepare musan $dl_dir/musan data/manifests
# touch data/manifests/.musan_manifests.done
# fi
# fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
if [ ! -f data/manifests/.musan_manifests.done ]; then
log "It may take 6 minutes"
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan_manifests.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for aishell"
@ -120,56 +120,29 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
fi
# if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
# log "Stage 30: Compute whisper fbank for aishell"
# if [ ! -f data/fbank/.aishell.done ]; then
# mkdir -p data/fbank
# ./local/compute_whisper_fbank_aishell.py --perturb-speed True
# touch data/fbank/.aishell.done
# fi
# fi
if [ $stage -le 300 ] && [ $stop_stage -ge 300 ]; then
log "Stage 30: Compute whisper fbank for aishell"
if [ ! -f data/fbank/.aishell.done ]; then
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.msuan.done ]; then
mkdir -p data/fbank
./local/compute_whisper_fbank_aishell.py --perturb-speed True --num-mel-bins 128
touch data/fbank/.aishell.done
./local/compute_fbank_musan.py
touch data/fbank/.msuan.done
fi
fi
# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
# log "Stage 4: Compute fbank for musan"
# if [ ! -f data/fbank/.msuan.done ]; then
# mkdir -p data/fbank
# ./local/compute_fbank_musan.py
# touch data/fbank/.msuan.done
# fi
# fi
lang_phone_dir=data/lang_phone
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
mkdir -p $lang_phone_dir
# if [ $stage -le 40 ] && [ $stop_stage -ge 40 ]; then
# log "Stage 4: Compute fbank for musan"
# if [ ! -f data/fbank/.msuan.done ]; then
# mkdir -p data/fbank
# ./local/compute_whisper_fbank_musan.py
# touch data/fbank/.msuan.done
# fi
# fi
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/aishell/resource_aishell/lexicon.txt |
sort | uniq > $lang_phone_dir/lexicon.txt
# lang_phone_dir=data/lang_phone
# if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
# log "Stage 5: Prepare phone based lang"
# mkdir -p $lang_phone_dir
./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
# (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
# cat - $dl_dir/aishell/resource_aishell/lexicon.txt |
# sort | uniq > $lang_phone_dir/lexicon.txt
# ./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
# if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
# ./local/prepare_lang.py --lang-dir $lang_phone_dir
# fi
if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_phone_dir
fi
# Train a bigram P for MMI training
@ -182,93 +155,93 @@ fi
cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
fi
# if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
# ./local/convert_transcript_words_to_tokens.py \
# --lexicon $lang_phone_dir/uniq_lexicon.txt \
# --transcript $lang_phone_dir/transcript_words.txt \
# --oov "<UNK>" \
# > $lang_phone_dir/transcript_tokens.txt
# fi
if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
./local/convert_transcript_words_to_tokens.py \
--lexicon $lang_phone_dir/uniq_lexicon.txt \
--transcript $lang_phone_dir/transcript_words.txt \
--oov "<UNK>" \
> $lang_phone_dir/transcript_tokens.txt
fi
# if [ ! -f $lang_phone_dir/P.arpa ]; then
# ./shared/make_kn_lm.py \
# -ngram-order 2 \
# -text $lang_phone_dir/transcript_tokens.txt \
# -lm $lang_phone_dir/P.arpa
# fi
if [ ! -f $lang_phone_dir/P.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 2 \
-text $lang_phone_dir/transcript_tokens.txt \
-lm $lang_phone_dir/P.arpa
fi
# if [ ! -f $lang_phone_dir/P.fst.txt ]; then
# python3 -m kaldilm \
# --read-symbol-table="$lang_phone_dir/tokens.txt" \
# --disambig-symbol='#0' \
# --max-order=2 \
# $lang_phone_dir/P.arpa > $lang_phone_dir/P.fst.txt
# fi
# fi
if [ ! -f $lang_phone_dir/P.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_phone_dir/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
$lang_phone_dir/P.arpa > $lang_phone_dir/P.fst.txt
fi
fi
# lang_char_dir=data/lang_char
# if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
# log "Stage 6: Prepare char based lang"
# mkdir -p $lang_char_dir
# # We reuse words.txt from phone based lexicon
# # so that the two can share G.pt later.
lang_char_dir=data/lang_char
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
mkdir -p $lang_char_dir
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
# # The transcripts in training set, generated in stage 5
# cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
# The transcripts in training set, generated in stage 5
cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
# cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
# cut -d " " -f 2- > $lang_char_dir/text
cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
cut -d " " -f 2- > $lang_char_dir/text
# (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
# > $lang_char_dir/words.txt
(echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
> $lang_char_dir/words.txt
# cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
# | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
| awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
# num_lines=$(< $lang_char_dir/words.txt wc -l)
# (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
# >> $lang_char_dir/words.txt
num_lines=$(< $lang_char_dir/words.txt wc -l)
(echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
>> $lang_char_dir/words.txt
# if [ ! -f $lang_char_dir/L_disambig.pt ]; then
# ./local/prepare_char.py --lang-dir $lang_char_dir
# fi
# fi
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
./local/prepare_char.py --lang-dir $lang_char_dir
fi
fi
# if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
# log "Stage 7: Prepare Byte BPE based lang"
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare Byte BPE based lang"
# for vocab_size in ${vocab_sizes[@]}; do
# lang_dir=data/lang_bbpe_${vocab_size}
# mkdir -p $lang_dir
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
mkdir -p $lang_dir
# cp $lang_char_dir/words.txt $lang_dir
# cp $lang_char_dir/text $lang_dir
cp $lang_char_dir/words.txt $lang_dir
cp $lang_char_dir/text $lang_dir
# if [ ! -f $lang_dir/bbpe.model ]; then
# ./local/train_bbpe_model.py \
# --lang-dir $lang_dir \
# --vocab-size $vocab_size \
# --transcript $lang_dir/text
# fi
if [ ! -f $lang_dir/bbpe.model ]; then
./local/train_bbpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/text
fi
# if [ ! -f $lang_dir/L_disambig.pt ]; then
# ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
# fi
# done
# fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bbpe.py --lang-dir $lang_dir
fi
done
fi
# if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
# log "Stage 8: Prepare G"
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare G"
# mkdir -p data/lm
mkdir -p data/lm
# # Train LM on transcripts
# if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
# python3 ./shared/make_kn_lm.py \
# -ngram-order 3 \
# -text $lang_char_dir/transcript_words.txt \
# -lm data/lm/3-gram.unpruned.arpa
# fi
# Train LM on transcripts
if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
python3 ./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lang_char_dir/transcript_words.txt \
-lm data/lm/3-gram.unpruned.arpa
fi
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
@ -294,112 +267,124 @@ fi
fi
fi
# if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
# log "Stage 9: Compile LG & HLG"
# ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
# ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
# for vocab_size in ${vocab_sizes[@]}; do
# lang_dir=data/lang_bbpe_${vocab_size}
# ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
# done
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compile LG & HLG"
./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
done
# ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
# ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
# for vocab_size in ${vocab_sizes[@]}; do
# lang_dir=data/lang_bbpe_${vocab_size}
# ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
# done
# fi
./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
done
fi
# if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
# log "Stage 10: Generate LM training data"
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Generate LM training data"
# log "Processing char based data"
# out_dir=data/lm_training_char
# mkdir -p $out_dir $dl_dir/lm
log "Processing char based data"
out_dir=data/lm_training_char
mkdir -p $out_dir $dl_dir/lm
# if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
# cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
# fi
if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
fi
# # training words
# ./local/prepare_char_lm_training_data.py \
# --lang-char data/lang_char \
# --lm-data $dl_dir/lm/aishell-train-word.txt \
# --lm-archive $out_dir/lm_data.pt
# training words
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $dl_dir/lm/aishell-train-word.txt \
--lm-archive $out_dir/lm_data.pt
# # valid words
# if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
# aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
# aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
# find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
# awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
# cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
# fi
# valid words
if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
fi
# ./local/prepare_char_lm_training_data.py \
# --lang-char data/lang_char \
# --lm-data $dl_dir/lm/aishell-valid-word.txt \
# --lm-archive $out_dir/lm_data_valid.pt
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $dl_dir/lm/aishell-valid-word.txt \
--lm-archive $out_dir/lm_data_valid.pt
# # test words
# if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
# aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
# aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
# find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
# awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
# cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
# fi
# test words
if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
fi
# ./local/prepare_char_lm_training_data.py \
# --lang-char data/lang_char \
# --lm-data $dl_dir/lm/aishell-test-word.txt \
# --lm-archive $out_dir/lm_data_test.pt
# fi
./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \
--lm-data $dl_dir/lm/aishell-test-word.txt \
--lm-archive $out_dir/lm_data_test.pt
fi
# if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
# log "Stage 11: Sort LM training data"
# # Sort LM training data by sentence length in descending order
# # for ease of training.
# #
# # Sentence length equals to the number of tokens
# # in a sentence.
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Sort LM training data"
# Sort LM training data by sentence length in descending order
# for ease of training.
#
# Sentence length equals to the number of tokens
# in a sentence.
# out_dir=data/lm_training_char
# mkdir -p $out_dir
# ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
out_dir=data/lm_training_char
mkdir -p $out_dir
ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
# ./local/sort_lm_training_data.py \
# --in-lm-data $out_dir/lm_data.pt \
# --out-lm-data $out_dir/sorted_lm_data.pt \
# --out-statistics $out_dir/statistics.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data.pt \
--out-lm-data $out_dir/sorted_lm_data.pt \
--out-statistics $out_dir/statistics.txt
# ./local/sort_lm_training_data.py \
# --in-lm-data $out_dir/lm_data_valid.pt \
# --out-lm-data $out_dir/sorted_lm_data-valid.pt \
# --out-statistics $out_dir/statistics-valid.txt
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data_valid.pt \
--out-lm-data $out_dir/sorted_lm_data-valid.pt \
--out-statistics $out_dir/statistics-valid.txt
# ./local/sort_lm_training_data.py \
# --in-lm-data $out_dir/lm_data_test.pt \
# --out-lm-data $out_dir/sorted_lm_data-test.pt \
# --out-statistics $out_dir/statistics-test.txt
# fi
./local/sort_lm_training_data.py \
--in-lm-data $out_dir/lm_data_test.pt \
--out-lm-data $out_dir/sorted_lm_data-test.pt \
--out-statistics $out_dir/statistics-test.txt
fi
# if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
# log "Stage 11: Train RNN LM model"
# python ../../../icefall/rnn_lm/train.py \
# --start-epoch 0 \
# --world-size 1 \
# --num-epochs 20 \
# --use-fp16 0 \
# --embedding-dim 512 \
# --hidden-dim 512 \
# --num-layers 2 \
# --batch-size 400 \
# --exp-dir rnnlm_char/exp \
# --lm-data $out_dir/sorted_lm_data.pt \
# --lm-data-valid $out_dir/sorted_lm_data-valid.pt \
# --vocab-size 4336 \
# --master-port 12345
# fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 11: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \
--world-size 1 \
--num-epochs 20 \
--use-fp16 0 \
--embedding-dim 512 \
--hidden-dim 512 \
--num-layers 2 \
--batch-size 400 \
--exp-dir rnnlm_char/exp \
--lm-data $out_dir/sorted_lm_data.pt \
--lm-data-valid $out_dir/sorted_lm_data-valid.pt \
--vocab-size 4336 \
--master-port 12345
fi
# whisper large-v3 using 128 mel bins, others using 80 mel bins
whisper_mel_bins=80
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning"
if [ ! -f data/fbank/.aishell.whisper.done ]; then
mkdir -p data/fbank
./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.aishell.whisper.done
fi
fi

40
egs/aishell/ASR/whisper/decode.py Normal file → Executable file
View File

@ -115,10 +115,8 @@ def remove_punctuation(text: str or List[str]):
result_text.append(t)
return result_text
else:
raise Exception(f'不支持该类型{type(text)}')
raise Exception(f'Not support type {type(text)}')
# 将繁体中文总成简体中文
def to_simple(text: str or List[str]):
if isinstance(text, str):
text = convert(text, 'zh-cn')
@ -130,7 +128,7 @@ def to_simple(text: str or List[str]):
result_text.append(t)
return result_text
else:
raise Exception(f'不支持该类型{type(text)}')
raise Exception(f'Not support type{type(text)}')
def get_parser():
parser = argparse.ArgumentParser(
@ -192,27 +190,11 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"subsampling_factor": 4,
"feature_dim": 80,
"nhead": 4,
"attention_dim": 512,
"num_encoder_layers": 12,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"search_beam": 20,
"output_beam": 7,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@ -264,12 +246,6 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype).transpose(1, 2)
# pad feature to T = 3000
#T = 3000
#if feature.shape[2] < T:
# feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
print(feature.shape,23333)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
@ -281,7 +257,6 @@ def decode_one_batch(
hyps = to_simple(hyps)
hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
print(hyps, 233333333)
key = "beam-search"
@ -467,17 +442,14 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
#test_sets = ["test"]
#test_dls = [test_dl]
test_sets = ["valid"]
test_dls = [valid_dl]
test_dl = aishell.test_dataloaders(aishell.test_cuts())
test_sets = ["valid", "test"]
test_dls = [valid_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,

13
egs/aishell/ASR/whisper/model.py Normal file → Executable file
View File

@ -168,10 +168,8 @@ class AudioEncoder(nn.Module):
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
# change whisper to process audio with any length
x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
#x = (x + self.positional_embedding).to(x.dtype)
for block in self.blocks:
x = block(x)
@ -224,7 +222,6 @@ class TextDecoder(nn.Module):
return logits
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
super().__init__()
@ -315,7 +312,6 @@ class Whisper(nn.Module):
self.decoder.apply(install_hooks)
return cache, hooks
#detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function
@ -432,9 +428,4 @@ def load_model(
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
# if alignment_heads is not None:
# model.set_alignment_heads(alignment_heads)
return model.to(device)
return model.to(device)

80
egs/aishell/ASR/whisper/train.py Normal file → Executable file
View File

@ -51,26 +51,24 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing import List
#from aishell import AIShell
#from asr_datamodule import AsrDataModule
from asr_datamodule import AishellAsrDataModule
#from decoder import Decoder
#from joiner import Joiner
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
#from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.functional import pad as pad_tensor
from torch.utils.tensorboard import SummaryWriter
#from zipformer import Zipformer
from icefall import diagnostics
#from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -109,13 +107,6 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
@ -209,19 +200,6 @@ def get_parser():
help="Add hooks to check for infinite module outputs and gradients.",
)
parser.add_argument(
"--save-every-n",
type=int,
default=4000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
""",
)
parser.add_argument(
"--keep-last-k",
type=int,
@ -314,11 +292,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 999999999999999999, # For the 100h subset, use 800
# parameters for zipformer
"feature_dim": 80,
"subsampling_factor": 4, # not passed in, this is fixed.
"warm_step": 100,
"valid_interval": 9999999,
"env_info": get_env_info(),
}
)
@ -491,26 +465,25 @@ def compute_loss(
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
# pad feature from B,80,T to B,80,3000
#feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
batch_idx_train = params.batch_idx_train
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
# remove spaces in texts
texts = [text.replace(" ", "") for text in texts]
#print(texts)
text_tokens_list = [list(tokenizer.sot_sequence_including_notimestamps) + tokenizer.encode(text) + [tokenizer.eot] for text in texts]
# convert it to torch tensor
text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]
# 50256 is the index of <pad> for all whisper models
prev_outputs_tokens = _batch_tensors(
[tokens[:-1] for tokens in text_tokens_list], pad_value=50256
)
@ -522,9 +495,10 @@ def compute_loss(
)
decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
# ignore the first 3 tokens, which are always <sos>, <lang_id>, <transcibe>
ignore_prefix_size = 3
with torch.set_grad_enabled(is_training):
encoder_out = model.encoder(feature)
text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
loss = decoder_criterion(text_logits, target_tokens.to(device))
@ -697,27 +671,6 @@ def train_one_epoch(
model_avg=model_avg,
)
# if (
# params.batch_idx_train > 0
# and params.batch_idx_train % params.save_every_n == 0
# ):
# save_checkpoint_with_global_batch_idx(
# out_dir=params.exp_dir,
# global_batch_idx=params.batch_idx_train,
# model=model,
# model_avg=model_avg,
# params=params,
# optimizer=optimizer,
# scheduler=scheduler,
# sampler=train_dl.sampler,
# scaler=scaler,
# rank=rank,
# )
# remove_checkpoints(
# out_dir=params.exp_dir,
# topk=params.keep_last_k,
# rank=rank,
# )
if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
@ -791,8 +744,7 @@ def run(rank, world_size, args):
logging.info(params)
logging.info("About to create model")
# TODO download model only on rank 0
# TODO may change compute validation loss using multiple cards
model = load_model(params.model_name)
del model.alignment_heads
num_param = sum([p.numel() for p in model.parameters()])
@ -802,7 +754,6 @@ def run(rank, world_size, args):
model.is_multilingual, num_languages=model.num_languages, language="zh", task="transcribe"
)
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
@ -863,9 +814,8 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = aishell.train_dataloaders(aishell.train_cuts(), rank=rank, world_size=world_size)
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts(), rank=rank, world_size=world_size)
train_dl = aishell.train_dataloaders(aishell.train_cuts())
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:

View File

@ -28,7 +28,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@ -45,11 +45,10 @@ def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
def compute_fbank_musan():
def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
dataset_parts = (
"music",
@ -81,7 +80,10 @@ def compute_fbank_musan():
logging.info("Extracting features for Musan")
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
@ -101,9 +103,26 @@ def compute_fbank_musan():
)
musan_cuts.to_file(musan_cuts_path)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()
compute_fbank_musan(
num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank
)