support using extracted codebook indexes

This commit is contained in:
yaozengwei 2022-06-16 21:31:26 +08:00
parent 1ed96824a0
commit 496abc30c0
3 changed files with 111 additions and 35 deletions

99
egs/librispeech/ASR/distillation_with_hubert.sh Normal file → Executable file
View File

@ -1,3 +1,5 @@
#!/usr/bin/env bash
#
# A short introduction about distillation framework.
#
# A typical traditional distillation method is
@ -22,7 +24,11 @@
# For example command
# bash distillation_with_hubert.sh 0
# will download hubert model.
stage=$1
set -x
stage=2
stop_stage=3
# Set the GPUs available.
# This script requires at least one GPU.
@ -33,10 +39,32 @@ stage=$1
# export CUDA_VISIBLE_DEVICES="0"
#
# Suppose GPU 2,3,4,5 are available.
export CUDA_VISIBLE_DEVICES="2,3,4,5"
export CUDA_VISIBLE_DEVICES="0,1,2,3"
exp_dir=./pruned_transducer_stateless6/exp
mkdir -p $exp_dir
if [ $stage -eq 0 ]; then
# full_libri can be "True" or "False"
# If "True", the distillation will use full librispeech dataset.
full_libri=False
# use_extracted_codebook can be "True" or "False"
# If "True", stage 0 and stage 1 would be skipped
use_extracted_codebook=False
# teacher_model_id can be one of
# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
# "hubert_xtralarge_ll60k.pt" -> pretrained model without fintuing
teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" == "True" ]; then
log "Stage 0: Download HuBERT model"
# Preparation stage.
# Install fairseq according to:
@ -45,7 +73,7 @@ if [ $stage -eq 0 ]; then
# commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
if [ $has_fairseq == 'False' ]; then
echo "Please install fairseq before running following stages"
log "Please install fairseq before running following stages"
exit 1
fi
@ -56,42 +84,41 @@ if [ $stage -eq 0 ]; then
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
if [ $has_quantization == 'False' ]; then
echo "Please install quantization before running following stages"
log "Please install quantization before running following stages"
exit 1
fi
echo "Download hubert model."
log "Download HuBERT model."
# Parameters about model.
exp_dir=./pruned_transducer_stateless6/exp/
model_id=hubert_xtralarge_ll60k_finetune_ls960
hubert_model_dir=${exp_dir}/hubert_models
hubert_model=${hubert_model_dir}/${model_id}.pt
hubert_model=${hubert_model_dir}/${teacher_model_id}.pt
mkdir -p ${hubert_model_dir}
# For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert
if [ -f ${hubert_model} ]; then
echo "hubert model alread exists."
log "HuBERT model alread exists."
else
wget -c https://dl.fbaipublicfiles.com/hubert/${model_id} -P ${hubert_model}
wget -c https://dl.fbaipublicfiles.com/hubert/${teacher_model_id}.pt -P ${hubert_model_dir}
wget -c wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
fi
fi
if [ ! -d ./data/fbank ]; then
echo "This script assumes ./data/fbank is already generated by prepare.sh"
log "This script assumes ./data/fbank is already generated by prepare.sh"
exit 1
fi
if [ $stage -eq 1 ]; then
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! "$use_extracted_codebook" == "True" ]; then
log "Stage 1: Verify that the downloaded HuBERT model is correct."
# This stage is not directly used by codebook indexes extraction.
# It is a method to "prove" that the downloaed hubert model
# is inferenced in an correct way if WERs look like normal.
# Expect WERs:
# [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
# [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
./pruned_transducer_stateless6/hubert_decode.py
./pruned_transducer_stateless6/hubert_decode.py --exp-dir $exp_dir
fi
if [ $stage -eq 2 ]; then
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# Analysis of disk usage:
# With num_codebooks==8, each teacher embedding is quantized into
# a sequence of eight 8-bit integers, i.e. only eight bytes are needed.
@ -113,25 +140,59 @@ if [ $stage -eq 2 ]; then
# During quantizer's training data(teacher embedding) and it's training,
# only the first ONE GPU is used.
# During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used.
if [ "$use_extracted_codebook" == "True" ]; then
if [ ! "$teacher_model_id" == "hubert_xtralarge_ll60k_finetune_ls960" ]; then
log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960"
exit 1
fi
mkdir -p $exp_dir/vq
codebook_dir=$exp_dir/vq/$teacher_model_id
mkdir -p codebook_dir
codebook_download_dir=$exp_dir/download_codebook
if [ -d $codebook_download_dir ]; then
log "$download_codebook exists, you should remove it first."
exit 1
fi
log "Downloading extracted codebook indexes to $codebook_download_dir"
# Make sure you have git-lfs installed (https://git-lfs.github.com)
git lfs install
git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
mkdir -p data/vq_fbank
mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
mkdir -p $codebook_dir/splits4
mv $codebook_download_dir/*.h5 $codebook_dir/splits4/
log "Remove $codebook_download_dir"
rm -rf $codebook_download_dir
fi
./pruned_transducer_stateless6/extract_codebook_index.py \
--full-libri False
--full-libri $full_libri \
--exp-dir $exp_dir \
--embedding-layer 36 \
--num-utts 1000 \
--num-codebooks 8 \
--max-duration 100 \
--teacher-model-id $teacher_model_id \
--use-extracted-codebook $use_extracted_codebook
fi
if [ $stage -eq 3 ]; then
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# Example training script.
# Note: it's better to set spec-aug-time-warpi-factor=-1
WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
./pruned_transducer_stateless6/train.py \
--manifest-dir ./data/vq_fbank \
--master-port 12359 \
--full-libri False \
--full-libri $full_libri \
--spec-aug-time-warp-factor -1 \
--max-duration 300 \
--world-size ${WORLD_SIZE} \
--num-epochs 20
fi
if [ $stage -eq 4 ]; then
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
# Results should be similar to:
# errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67
# errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60

View File

@ -24,7 +24,7 @@ import torch
from vq_utils import CodebookIndexExtractor
from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import AttributeDict
from icefall.utils import AttributeDict, str2bool
def get_parser():
@ -38,6 +38,13 @@ def get_parser():
help="The experiment dir",
)
parser.add_argument(
"--use-extracted-codebook",
type=str2bool,
default=False,
help="Whether to use the extracted codebook indexes.",
)
return parser
@ -71,9 +78,13 @@ def main():
params.world_size = world_size
extractor = CodebookIndexExtractor(params=params)
extractor.extract_and_save_embedding()
extractor.train_quantizer()
extractor.extract_codebook_indexes()
if not params.use_extracted_codebook:
extractor.extract_and_save_embedding()
extractor.train_quantizer()
extractor.extract_codebook_indexes()
extractor.reuse_manifests()
extractor.join_manifests()
if __name__ == "__main__":

View File

@ -63,17 +63,15 @@ class CodebookIndexExtractor:
setup_logger(f"{self.vq_dir}/log-vq_extraction")
def init_dirs(self):
# TODO:
# vq_dir is the root dir for quantizer:
# training data/ quantizer / extracted codebook indexes
# vq_dir is the root dir for quantization, containing:
# training data, trained quantizer, and extracted codebook indexes
self.vq_dir = (
self.params.exp_dir / f"vq/{self.params.teacher_model_id}/"
)
self.vq_dir.mkdir(parents=True, exist_ok=True)
# manifest_dir for :
# splited original manifests,
# extracted codebook indexes and their related manifests
# manifest_dir contains:
# splited original manifests, extracted codebook indexes with related manifests # noqa
self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}"
self.manifest_dir.mkdir(parents=True, exist_ok=True)
@ -137,6 +135,7 @@ class CodebookIndexExtractor:
logging.warn(warn_message)
return
logging.info("Start to extract embeddings for training the quantizer.")
total_cuts = 0
with NumpyHdf5Writer(self.embedding_file_path) as writer:
for batch_idx, batch in enumerate(self.quantizer_train_dl):
@ -189,14 +188,15 @@ class CodebookIndexExtractor:
return
assert self.embedding_file_path.exists()
logging.info("Start to train quantizer.")
trainer = quantization.QuantizerTrainer(
dim=self.params.embedding_dim,
bytes_per_frame=self.params.num_codebooks,
device=self.params.device,
)
train, valid = quantization.read_hdf5_data(self.embedding_file_path)
B = 512 # Minibatch size, this is very arbitrary, it's close to what we used
# when we tuned this method.
B = 512 # Minibatch size, this is very arbitrary,
# it's close to what we used when we tuned this method.
def minibatch_generator(data: torch.Tensor, repeat: bool):
assert 3 * B < data.shape[0]
@ -231,8 +231,10 @@ class CodebookIndexExtractor:
os.system(f"{split_cmd}")
def join_manifests(self):
"""TODO:"""
"""
Join the vq manifest to the original manifest according to cut id.
"""
logging.info("Start to join manifest files.")
for subset in self.params.subsets:
vq_manifest_path = (
self.dst_manifest_dir
@ -254,6 +256,8 @@ class CodebookIndexExtractor:
cut_ori.codebook_indexes = cut_vq.codebook_indexes
CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
logging.info(f"Processed {subset}.")
logging.info(f"Saved to {dst_vq_manifest_path}.")
def merge_vq_manifests(self):
"""
@ -303,7 +307,6 @@ class CodebookIndexExtractor:
os.symlink(ori_manifest_path, dst_manifest_path)
def create_vq_fbank(self):
self.reuse_manifests()
self.merge_vq_manifests()
@cached_property
@ -330,7 +333,7 @@ class CodebookIndexExtractor:
else:
ori_manifest_path = (
self.manifest_dir
/ f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz"
/ f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz" # noqa
)
cuts = load_manifest(ori_manifest_path)
@ -343,6 +346,7 @@ class CodebookIndexExtractor:
torch.cuda.empty_cache()
def extract_codebook_indexes(self):
logging.info("Start to extract codebook indexes.")
if self.params.world_size == 1:
self.extract_codebook_indexes_imp()
else: