mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 13:34:20 +00:00
support using extracted codebook indexes
This commit is contained in:
parent
1ed96824a0
commit
496abc30c0
99
egs/librispeech/ASR/distillation_with_hubert.sh
Normal file → Executable file
99
egs/librispeech/ASR/distillation_with_hubert.sh
Normal file → Executable file
@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# A short introduction about distillation framework.
|
||||
#
|
||||
# A typical traditional distillation method is
|
||||
@ -22,7 +24,11 @@
|
||||
# For example command
|
||||
# bash distillation_with_hubert.sh 0
|
||||
# will download hubert model.
|
||||
stage=$1
|
||||
|
||||
set -x
|
||||
|
||||
stage=2
|
||||
stop_stage=3
|
||||
|
||||
# Set the GPUs available.
|
||||
# This script requires at least one GPU.
|
||||
@ -33,10 +39,32 @@ stage=$1
|
||||
# export CUDA_VISIBLE_DEVICES="0"
|
||||
#
|
||||
# Suppose GPU 2,3,4,5 are available.
|
||||
export CUDA_VISIBLE_DEVICES="2,3,4,5"
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
exp_dir=./pruned_transducer_stateless6/exp
|
||||
mkdir -p $exp_dir
|
||||
|
||||
if [ $stage -eq 0 ]; then
|
||||
# full_libri can be "True" or "False"
|
||||
# If "True", the distillation will use full librispeech dataset.
|
||||
full_libri=False
|
||||
|
||||
# use_extracted_codebook can be "True" or "False"
|
||||
# If "True", stage 0 and stage 1 would be skipped
|
||||
use_extracted_codebook=False
|
||||
|
||||
# teacher_model_id can be one of
|
||||
# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
|
||||
# "hubert_xtralarge_ll60k.pt" -> pretrained model without fintuing
|
||||
teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" == "True" ]; then
|
||||
log "Stage 0: Download HuBERT model"
|
||||
# Preparation stage.
|
||||
|
||||
# Install fairseq according to:
|
||||
@ -45,7 +73,7 @@ if [ $stage -eq 0 ]; then
|
||||
# commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
|
||||
has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
|
||||
if [ $has_fairseq == 'False' ]; then
|
||||
echo "Please install fairseq before running following stages"
|
||||
log "Please install fairseq before running following stages"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@ -56,42 +84,41 @@ if [ $stage -eq 0 ]; then
|
||||
|
||||
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
|
||||
if [ $has_quantization == 'False' ]; then
|
||||
echo "Please install quantization before running following stages"
|
||||
log "Please install quantization before running following stages"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Download hubert model."
|
||||
log "Download HuBERT model."
|
||||
# Parameters about model.
|
||||
exp_dir=./pruned_transducer_stateless6/exp/
|
||||
model_id=hubert_xtralarge_ll60k_finetune_ls960
|
||||
hubert_model_dir=${exp_dir}/hubert_models
|
||||
hubert_model=${hubert_model_dir}/${model_id}.pt
|
||||
hubert_model=${hubert_model_dir}/${teacher_model_id}.pt
|
||||
mkdir -p ${hubert_model_dir}
|
||||
# For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert
|
||||
if [ -f ${hubert_model} ]; then
|
||||
echo "hubert model alread exists."
|
||||
log "HuBERT model alread exists."
|
||||
else
|
||||
wget -c https://dl.fbaipublicfiles.com/hubert/${model_id} -P ${hubert_model}
|
||||
wget -c https://dl.fbaipublicfiles.com/hubert/${teacher_model_id}.pt -P ${hubert_model_dir}
|
||||
wget -c wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -d ./data/fbank ]; then
|
||||
echo "This script assumes ./data/fbank is already generated by prepare.sh"
|
||||
log "This script assumes ./data/fbank is already generated by prepare.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -eq 1 ]; then
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! "$use_extracted_codebook" == "True" ]; then
|
||||
log "Stage 1: Verify that the downloaded HuBERT model is correct."
|
||||
# This stage is not directly used by codebook indexes extraction.
|
||||
# It is a method to "prove" that the downloaed hubert model
|
||||
# is inferenced in an correct way if WERs look like normal.
|
||||
# Expect WERs:
|
||||
# [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
|
||||
# [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
|
||||
./pruned_transducer_stateless6/hubert_decode.py
|
||||
./pruned_transducer_stateless6/hubert_decode.py --exp-dir $exp_dir
|
||||
fi
|
||||
|
||||
if [ $stage -eq 2 ]; then
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
# Analysis of disk usage:
|
||||
# With num_codebooks==8, each teacher embedding is quantized into
|
||||
# a sequence of eight 8-bit integers, i.e. only eight bytes are needed.
|
||||
@ -113,25 +140,59 @@ if [ $stage -eq 2 ]; then
|
||||
# During quantizer's training data(teacher embedding) and it's training,
|
||||
# only the first ONE GPU is used.
|
||||
# During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used.
|
||||
|
||||
if [ "$use_extracted_codebook" == "True" ]; then
|
||||
if [ ! "$teacher_model_id" == "hubert_xtralarge_ll60k_finetune_ls960" ]; then
|
||||
log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960"
|
||||
exit 1
|
||||
fi
|
||||
mkdir -p $exp_dir/vq
|
||||
codebook_dir=$exp_dir/vq/$teacher_model_id
|
||||
mkdir -p codebook_dir
|
||||
codebook_download_dir=$exp_dir/download_codebook
|
||||
if [ -d $codebook_download_dir ]; then
|
||||
log "$download_codebook exists, you should remove it first."
|
||||
exit 1
|
||||
fi
|
||||
log "Downloading extracted codebook indexes to $codebook_download_dir"
|
||||
# Make sure you have git-lfs installed (https://git-lfs.github.com)
|
||||
git lfs install
|
||||
git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
|
||||
|
||||
mkdir -p data/vq_fbank
|
||||
mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
|
||||
mkdir -p $codebook_dir/splits4
|
||||
mv $codebook_download_dir/*.h5 $codebook_dir/splits4/
|
||||
log "Remove $codebook_download_dir"
|
||||
rm -rf $codebook_download_dir
|
||||
fi
|
||||
|
||||
./pruned_transducer_stateless6/extract_codebook_index.py \
|
||||
--full-libri False
|
||||
--full-libri $full_libri \
|
||||
--exp-dir $exp_dir \
|
||||
--embedding-layer 36 \
|
||||
--num-utts 1000 \
|
||||
--num-codebooks 8 \
|
||||
--max-duration 100 \
|
||||
--teacher-model-id $teacher_model_id \
|
||||
--use-extracted-codebook $use_extracted_codebook
|
||||
fi
|
||||
|
||||
if [ $stage -eq 3 ]; then
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
# Example training script.
|
||||
# Note: it's better to set spec-aug-time-warpi-factor=-1
|
||||
WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
|
||||
./pruned_transducer_stateless6/train.py \
|
||||
--manifest-dir ./data/vq_fbank \
|
||||
--master-port 12359 \
|
||||
--full-libri False \
|
||||
--full-libri $full_libri \
|
||||
--spec-aug-time-warp-factor -1 \
|
||||
--max-duration 300 \
|
||||
--world-size ${WORLD_SIZE} \
|
||||
--num-epochs 20
|
||||
fi
|
||||
|
||||
if [ $stage -eq 4 ]; then
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
# Results should be similar to:
|
||||
# errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67
|
||||
# errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60
|
||||
|
@ -24,7 +24,7 @@ import torch
|
||||
from vq_utils import CodebookIndexExtractor
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from hubert_xlarge import HubertXlargeFineTuned
|
||||
from icefall.utils import AttributeDict
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -38,6 +38,13 @@ def get_parser():
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-extracted-codebook",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use the extracted codebook indexes.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -71,9 +78,13 @@ def main():
|
||||
params.world_size = world_size
|
||||
|
||||
extractor = CodebookIndexExtractor(params=params)
|
||||
extractor.extract_and_save_embedding()
|
||||
extractor.train_quantizer()
|
||||
extractor.extract_codebook_indexes()
|
||||
if not params.use_extracted_codebook:
|
||||
extractor.extract_and_save_embedding()
|
||||
extractor.train_quantizer()
|
||||
extractor.extract_codebook_indexes()
|
||||
|
||||
extractor.reuse_manifests()
|
||||
extractor.join_manifests()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -63,17 +63,15 @@ class CodebookIndexExtractor:
|
||||
setup_logger(f"{self.vq_dir}/log-vq_extraction")
|
||||
|
||||
def init_dirs(self):
|
||||
# TODO:
|
||||
# vq_dir is the root dir for quantizer:
|
||||
# training data/ quantizer / extracted codebook indexes
|
||||
# vq_dir is the root dir for quantization, containing:
|
||||
# training data, trained quantizer, and extracted codebook indexes
|
||||
self.vq_dir = (
|
||||
self.params.exp_dir / f"vq/{self.params.teacher_model_id}/"
|
||||
)
|
||||
self.vq_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# manifest_dir for :
|
||||
# splited original manifests,
|
||||
# extracted codebook indexes and their related manifests
|
||||
# manifest_dir contains:
|
||||
# splited original manifests, extracted codebook indexes with related manifests # noqa
|
||||
self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}"
|
||||
self.manifest_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@ -137,6 +135,7 @@ class CodebookIndexExtractor:
|
||||
logging.warn(warn_message)
|
||||
return
|
||||
|
||||
logging.info("Start to extract embeddings for training the quantizer.")
|
||||
total_cuts = 0
|
||||
with NumpyHdf5Writer(self.embedding_file_path) as writer:
|
||||
for batch_idx, batch in enumerate(self.quantizer_train_dl):
|
||||
@ -189,14 +188,15 @@ class CodebookIndexExtractor:
|
||||
return
|
||||
|
||||
assert self.embedding_file_path.exists()
|
||||
logging.info("Start to train quantizer.")
|
||||
trainer = quantization.QuantizerTrainer(
|
||||
dim=self.params.embedding_dim,
|
||||
bytes_per_frame=self.params.num_codebooks,
|
||||
device=self.params.device,
|
||||
)
|
||||
train, valid = quantization.read_hdf5_data(self.embedding_file_path)
|
||||
B = 512 # Minibatch size, this is very arbitrary, it's close to what we used
|
||||
# when we tuned this method.
|
||||
B = 512 # Minibatch size, this is very arbitrary,
|
||||
# it's close to what we used when we tuned this method.
|
||||
|
||||
def minibatch_generator(data: torch.Tensor, repeat: bool):
|
||||
assert 3 * B < data.shape[0]
|
||||
@ -231,8 +231,10 @@ class CodebookIndexExtractor:
|
||||
os.system(f"{split_cmd}")
|
||||
|
||||
def join_manifests(self):
|
||||
"""TODO:"""
|
||||
|
||||
"""
|
||||
Join the vq manifest to the original manifest according to cut id.
|
||||
"""
|
||||
logging.info("Start to join manifest files.")
|
||||
for subset in self.params.subsets:
|
||||
vq_manifest_path = (
|
||||
self.dst_manifest_dir
|
||||
@ -254,6 +256,8 @@ class CodebookIndexExtractor:
|
||||
cut_ori.codebook_indexes = cut_vq.codebook_indexes
|
||||
|
||||
CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
|
||||
logging.info(f"Processed {subset}.")
|
||||
logging.info(f"Saved to {dst_vq_manifest_path}.")
|
||||
|
||||
def merge_vq_manifests(self):
|
||||
"""
|
||||
@ -303,7 +307,6 @@ class CodebookIndexExtractor:
|
||||
os.symlink(ori_manifest_path, dst_manifest_path)
|
||||
|
||||
def create_vq_fbank(self):
|
||||
self.reuse_manifests()
|
||||
self.merge_vq_manifests()
|
||||
|
||||
@cached_property
|
||||
@ -330,7 +333,7 @@ class CodebookIndexExtractor:
|
||||
else:
|
||||
ori_manifest_path = (
|
||||
self.manifest_dir
|
||||
/ f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz"
|
||||
/ f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz" # noqa
|
||||
)
|
||||
|
||||
cuts = load_manifest(ori_manifest_path)
|
||||
@ -343,6 +346,7 @@ class CodebookIndexExtractor:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def extract_codebook_indexes(self):
|
||||
logging.info("Start to extract codebook indexes.")
|
||||
if self.params.world_size == 1:
|
||||
self.extract_codebook_indexes_imp()
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user