support using extracted codebook indexes

This commit is contained in:
yaozengwei 2022-06-16 21:31:26 +08:00
parent 1ed96824a0
commit 496abc30c0
3 changed files with 111 additions and 35 deletions

101
egs/librispeech/ASR/distillation_with_hubert.sh Normal file → Executable file
View File

@ -1,3 +1,5 @@
#!/usr/bin/env bash
#
# A short introduction about distillation framework. # A short introduction about distillation framework.
# #
# A typical traditional distillation method is # A typical traditional distillation method is
@ -22,7 +24,11 @@
# For example command # For example command
# bash distillation_with_hubert.sh 0 # bash distillation_with_hubert.sh 0
# will download hubert model. # will download hubert model.
stage=$1
set -x
stage=2
stop_stage=3
# Set the GPUs available. # Set the GPUs available.
# This script requires at least one GPU. # This script requires at least one GPU.
@ -33,10 +39,32 @@ stage=$1
# export CUDA_VISIBLE_DEVICES="0" # export CUDA_VISIBLE_DEVICES="0"
# #
# Suppose GPU 2,3,4,5 are available. # Suppose GPU 2,3,4,5 are available.
export CUDA_VISIBLE_DEVICES="2,3,4,5" export CUDA_VISIBLE_DEVICES="0,1,2,3"
exp_dir=./pruned_transducer_stateless6/exp
mkdir -p $exp_dir
if [ $stage -eq 0 ]; then # full_libri can be "True" or "False"
# If "True", the distillation will use full librispeech dataset.
full_libri=False
# use_extracted_codebook can be "True" or "False"
# If "True", stage 0 and stage 1 would be skipped
use_extracted_codebook=False
# teacher_model_id can be one of
# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
# "hubert_xtralarge_ll60k.pt" -> pretrained model without fintuing
teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" == "True" ]; then
log "Stage 0: Download HuBERT model"
# Preparation stage. # Preparation stage.
# Install fairseq according to: # Install fairseq according to:
@ -45,7 +73,7 @@ if [ $stage -eq 0 ]; then
# commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used. # commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)") has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
if [ $has_fairseq == 'False' ]; then if [ $has_fairseq == 'False' ]; then
echo "Please install fairseq before running following stages" log "Please install fairseq before running following stages"
exit 1 exit 1
fi fi
@ -56,42 +84,41 @@ if [ $stage -eq 0 ]; then
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)") has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
if [ $has_quantization == 'False' ]; then if [ $has_quantization == 'False' ]; then
echo "Please install quantization before running following stages" log "Please install quantization before running following stages"
exit 1 exit 1
fi fi
echo "Download hubert model." log "Download HuBERT model."
# Parameters about model. # Parameters about model.
exp_dir=./pruned_transducer_stateless6/exp/
model_id=hubert_xtralarge_ll60k_finetune_ls960
hubert_model_dir=${exp_dir}/hubert_models hubert_model_dir=${exp_dir}/hubert_models
hubert_model=${hubert_model_dir}/${model_id}.pt hubert_model=${hubert_model_dir}/${teacher_model_id}.pt
mkdir -p ${hubert_model_dir} mkdir -p ${hubert_model_dir}
# For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert # For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert
if [ -f ${hubert_model} ]; then if [ -f ${hubert_model} ]; then
echo "hubert model alread exists." log "HuBERT model alread exists."
else else
wget -c https://dl.fbaipublicfiles.com/hubert/${model_id} -P ${hubert_model} wget -c https://dl.fbaipublicfiles.com/hubert/${teacher_model_id}.pt -P ${hubert_model_dir}
wget -c wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir} wget -c wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
fi fi
fi fi
if [ ! -d ./data/fbank ]; then if [ ! -d ./data/fbank ]; then
echo "This script assumes ./data/fbank is already generated by prepare.sh" log "This script assumes ./data/fbank is already generated by prepare.sh"
exit 1 exit 1
fi fi
if [ $stage -eq 1 ]; then if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! "$use_extracted_codebook" == "True" ]; then
log "Stage 1: Verify that the downloaded HuBERT model is correct."
# This stage is not directly used by codebook indexes extraction. # This stage is not directly used by codebook indexes extraction.
# It is a method to "prove" that the downloaed hubert model # It is a method to "prove" that the downloaed hubert model
# is inferenced in an correct way if WERs look like normal. # is inferenced in an correct way if WERs look like normal.
# Expect WERs: # Expect WERs:
# [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ] # [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
# [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ] # [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
./pruned_transducer_stateless6/hubert_decode.py ./pruned_transducer_stateless6/hubert_decode.py --exp-dir $exp_dir
fi fi
if [ $stage -eq 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# Analysis of disk usage: # Analysis of disk usage:
# With num_codebooks==8, each teacher embedding is quantized into # With num_codebooks==8, each teacher embedding is quantized into
# a sequence of eight 8-bit integers, i.e. only eight bytes are needed. # a sequence of eight 8-bit integers, i.e. only eight bytes are needed.
@ -113,25 +140,59 @@ if [ $stage -eq 2 ]; then
# During quantizer's training data(teacher embedding) and it's training, # During quantizer's training data(teacher embedding) and it's training,
# only the first ONE GPU is used. # only the first ONE GPU is used.
# During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used. # During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used.
./pruned_transducer_stateless6/extract_codebook_index.py \
--full-libri False if [ "$use_extracted_codebook" == "True" ]; then
if [ ! "$teacher_model_id" == "hubert_xtralarge_ll60k_finetune_ls960" ]; then
log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960"
exit 1
fi
mkdir -p $exp_dir/vq
codebook_dir=$exp_dir/vq/$teacher_model_id
mkdir -p codebook_dir
codebook_download_dir=$exp_dir/download_codebook
if [ -d $codebook_download_dir ]; then
log "$download_codebook exists, you should remove it first."
exit 1
fi
log "Downloading extracted codebook indexes to $codebook_download_dir"
# Make sure you have git-lfs installed (https://git-lfs.github.com)
git lfs install
git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
mkdir -p data/vq_fbank
mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
mkdir -p $codebook_dir/splits4
mv $codebook_download_dir/*.h5 $codebook_dir/splits4/
log "Remove $codebook_download_dir"
rm -rf $codebook_download_dir
fi fi
if [ $stage -eq 3 ]; then ./pruned_transducer_stateless6/extract_codebook_index.py \
--full-libri $full_libri \
--exp-dir $exp_dir \
--embedding-layer 36 \
--num-utts 1000 \
--num-codebooks 8 \
--max-duration 100 \
--teacher-model-id $teacher_model_id \
--use-extracted-codebook $use_extracted_codebook
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# Example training script. # Example training script.
# Note: it's better to set spec-aug-time-warpi-factor=-1 # Note: it's better to set spec-aug-time-warpi-factor=-1
WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}') WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
./pruned_transducer_stateless6/train.py \ ./pruned_transducer_stateless6/train.py \
--manifest-dir ./data/vq_fbank \ --manifest-dir ./data/vq_fbank \
--master-port 12359 \ --master-port 12359 \
--full-libri False \ --full-libri $full_libri \
--spec-aug-time-warp-factor -1 \ --spec-aug-time-warp-factor -1 \
--max-duration 300 \ --max-duration 300 \
--world-size ${WORLD_SIZE} \ --world-size ${WORLD_SIZE} \
--num-epochs 20 --num-epochs 20
fi fi
if [ $stage -eq 4 ]; then if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
# Results should be similar to: # Results should be similar to:
# errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67 # errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67
# errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60 # errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60

View File

@ -24,7 +24,7 @@ import torch
from vq_utils import CodebookIndexExtractor from vq_utils import CodebookIndexExtractor
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import AttributeDict from icefall.utils import AttributeDict, str2bool
def get_parser(): def get_parser():
@ -38,6 +38,13 @@ def get_parser():
help="The experiment dir", help="The experiment dir",
) )
parser.add_argument(
"--use-extracted-codebook",
type=str2bool,
default=False,
help="Whether to use the extracted codebook indexes.",
)
return parser return parser
@ -71,10 +78,14 @@ def main():
params.world_size = world_size params.world_size = world_size
extractor = CodebookIndexExtractor(params=params) extractor = CodebookIndexExtractor(params=params)
if not params.use_extracted_codebook:
extractor.extract_and_save_embedding() extractor.extract_and_save_embedding()
extractor.train_quantizer() extractor.train_quantizer()
extractor.extract_codebook_indexes() extractor.extract_codebook_indexes()
extractor.reuse_manifests()
extractor.join_manifests()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -63,17 +63,15 @@ class CodebookIndexExtractor:
setup_logger(f"{self.vq_dir}/log-vq_extraction") setup_logger(f"{self.vq_dir}/log-vq_extraction")
def init_dirs(self): def init_dirs(self):
# TODO: # vq_dir is the root dir for quantization, containing:
# vq_dir is the root dir for quantizer: # training data, trained quantizer, and extracted codebook indexes
# training data/ quantizer / extracted codebook indexes
self.vq_dir = ( self.vq_dir = (
self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" self.params.exp_dir / f"vq/{self.params.teacher_model_id}/"
) )
self.vq_dir.mkdir(parents=True, exist_ok=True) self.vq_dir.mkdir(parents=True, exist_ok=True)
# manifest_dir for : # manifest_dir contains:
# splited original manifests, # splited original manifests, extracted codebook indexes with related manifests # noqa
# extracted codebook indexes and their related manifests
self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}" self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}"
self.manifest_dir.mkdir(parents=True, exist_ok=True) self.manifest_dir.mkdir(parents=True, exist_ok=True)
@ -137,6 +135,7 @@ class CodebookIndexExtractor:
logging.warn(warn_message) logging.warn(warn_message)
return return
logging.info("Start to extract embeddings for training the quantizer.")
total_cuts = 0 total_cuts = 0
with NumpyHdf5Writer(self.embedding_file_path) as writer: with NumpyHdf5Writer(self.embedding_file_path) as writer:
for batch_idx, batch in enumerate(self.quantizer_train_dl): for batch_idx, batch in enumerate(self.quantizer_train_dl):
@ -189,14 +188,15 @@ class CodebookIndexExtractor:
return return
assert self.embedding_file_path.exists() assert self.embedding_file_path.exists()
logging.info("Start to train quantizer.")
trainer = quantization.QuantizerTrainer( trainer = quantization.QuantizerTrainer(
dim=self.params.embedding_dim, dim=self.params.embedding_dim,
bytes_per_frame=self.params.num_codebooks, bytes_per_frame=self.params.num_codebooks,
device=self.params.device, device=self.params.device,
) )
train, valid = quantization.read_hdf5_data(self.embedding_file_path) train, valid = quantization.read_hdf5_data(self.embedding_file_path)
B = 512 # Minibatch size, this is very arbitrary, it's close to what we used B = 512 # Minibatch size, this is very arbitrary,
# when we tuned this method. # it's close to what we used when we tuned this method.
def minibatch_generator(data: torch.Tensor, repeat: bool): def minibatch_generator(data: torch.Tensor, repeat: bool):
assert 3 * B < data.shape[0] assert 3 * B < data.shape[0]
@ -231,8 +231,10 @@ class CodebookIndexExtractor:
os.system(f"{split_cmd}") os.system(f"{split_cmd}")
def join_manifests(self): def join_manifests(self):
"""TODO:""" """
Join the vq manifest to the original manifest according to cut id.
"""
logging.info("Start to join manifest files.")
for subset in self.params.subsets: for subset in self.params.subsets:
vq_manifest_path = ( vq_manifest_path = (
self.dst_manifest_dir self.dst_manifest_dir
@ -254,6 +256,8 @@ class CodebookIndexExtractor:
cut_ori.codebook_indexes = cut_vq.codebook_indexes cut_ori.codebook_indexes = cut_vq.codebook_indexes
CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path) CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
logging.info(f"Processed {subset}.")
logging.info(f"Saved to {dst_vq_manifest_path}.")
def merge_vq_manifests(self): def merge_vq_manifests(self):
""" """
@ -303,7 +307,6 @@ class CodebookIndexExtractor:
os.symlink(ori_manifest_path, dst_manifest_path) os.symlink(ori_manifest_path, dst_manifest_path)
def create_vq_fbank(self): def create_vq_fbank(self):
self.reuse_manifests()
self.merge_vq_manifests() self.merge_vq_manifests()
@cached_property @cached_property
@ -330,7 +333,7 @@ class CodebookIndexExtractor:
else: else:
ori_manifest_path = ( ori_manifest_path = (
self.manifest_dir self.manifest_dir
/ f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz" / f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz" # noqa
) )
cuts = load_manifest(ori_manifest_path) cuts = load_manifest(ori_manifest_path)
@ -343,6 +346,7 @@ class CodebookIndexExtractor:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def extract_codebook_indexes(self): def extract_codebook_indexes(self):
logging.info("Start to extract codebook indexes.")
if self.params.world_size == 1: if self.params.world_size == 1:
self.extract_codebook_indexes_imp() self.extract_codebook_indexes_imp()
else: else: