diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh
old mode 100644
new mode 100755
index e18ba8f55..3d4c4856a
--- a/egs/librispeech/ASR/distillation_with_hubert.sh
+++ b/egs/librispeech/ASR/distillation_with_hubert.sh
@@ -1,3 +1,5 @@
+#!/usr/bin/env bash
+#
 # A short introduction to the distillation framework.
 #
 # A typical traditional distillation method is
@@ -14,15 +16,15 @@
 #    teacher embeddings.
 # 3. a middle layer 6(1-based) out of total 6 layers is used to extract
 #    student embeddings.
-
-# This is an example to do distillation with librispeech clean-100 subset.
-# run with command:
-# bash distillation_with_hubert.sh [0|1|2|3|4]
 #
-# For example command
-# bash distillation_with_hubert.sh 0
-# will download hubert model.
-stage=$1
+# To directly download the extracted codebook indexes for model distillation, you can
+# set stage=2, stop_stage=4, use_extracted_codebook=True
+#
+# To start from scratch, you can
+# set stage=0, stop_stage=4, use_extracted_codebook=False

+stage=0
+stop_stage=4

 # Set the GPUs available.
 # This script requires at least one GPU.
@@ -33,10 +35,35 @@ stage=$1
 # export CUDA_VISIBLE_DEVICES="0"
 #
 # Suppose GPU 2,3,4,5 are available.
-export CUDA_VISIBLE_DEVICES="2,3,4,5"
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+exp_dir=./pruned_transducer_stateless6/exp
+mkdir -p $exp_dir

-if [ $stage -eq 0 ]; then
+# full_libri can be "True" or "False"
+# "True" -> use the full librispeech dataset for distillation
+# "False" -> use the train-clean-100 subset for distillation
+full_libri=False
+
+# use_extracted_codebook can be "True" or "False"
+# "True" -> stage 0 and stage 1 would be skipped,
+#           and the extracted codebook indexes are downloaded directly for distillation
+# "False" -> start from scratch
+use_extracted_codebook=False
+
+# teacher_model_id can be one of
+# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model; it is the one we currently use.
+# "hubert_xtralarge_ll60k" -> pretrained model without fine-tuning
+teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" == "True" ]; then
+  log "Stage 0: Download HuBERT model"
   # Preparation stage.

   # Install fairseq according to:
@@ -45,7 +72,7 @@ if [ $stage -eq 0 ]; then
   #   commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
   has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
   if [ $has_fairseq == 'False' ]; then
-    echo "Please install fairseq before running following stages"
+    log "Please install fairseq before running the following stages"
     exit 1
   fi

@@ -56,42 +83,41 @@ if [ $stage -eq 0 ]; then
   has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
   if [ $has_quantization == 'False' ]; then
-    echo "Please install quantization before running following stages"
+    log "Please install quantization before running the following stages"
     exit 1
   fi

-  echo "Download hubert model."
+  log "Download HuBERT model."

   # Parameters about model.
-  exp_dir=./pruned_transducer_stateless6/exp/
-  model_id=hubert_xtralarge_ll60k_finetune_ls960
   hubert_model_dir=${exp_dir}/hubert_models
-  hubert_model=${hubert_model_dir}/${model_id}.pt
+  hubert_model=${hubert_model_dir}/${teacher_model_id}.pt
   mkdir -p ${hubert_model_dir}
   # For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert
   if [ -f ${hubert_model} ]; then
-    echo "hubert model alread exists."
+    log "HuBERT model already exists."
   else
-    wget -c https://dl.fbaipublicfiles.com/hubert/${model_id} -P ${hubert_model}
+    wget -c https://dl.fbaipublicfiles.com/hubert/${teacher_model_id}.pt -P ${hubert_model_dir}
     wget -c https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
   fi
 fi

 if [ ! -d ./data/fbank ]; then
-  echo "This script assumes ./data/fbank is already generated by prepare.sh"
+  log "This script assumes ./data/fbank is already generated by prepare.sh"
   exit 1
 fi

-if [ $stage -eq 1 ]; then
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! "$use_extracted_codebook" == "True" ]; then
+  log "Stage 1: Verify that the downloaded HuBERT model is correct."
   # This stage is not directly used by codebook indexes extraction.
   # It is a method to "prove" that the downloaded hubert model
   # is inferenced correctly if the WERs look normal.
   # Expected WERs:
   # [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
   # [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
-  ./pruned_transducer_stateless6/hubert_decode.py
+  ./pruned_transducer_stateless6/hubert_decode.py --exp-dir $exp_dir
 fi

-if [ $stage -eq 2 ]; then
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # Analysis of disk usage:
   # With num_codebooks==8, each teacher embedding is quantized into
   # a sequence of eight 8-bit integers, i.e. only eight bytes are needed.
@@ -113,25 +139,61 @@ if [ $stage -eq 2 ]; then
   # During the extraction of the quantizer's training data (teacher embeddings) and its training,
   # only the first ONE GPU is used.
   # During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used.
+
+  if [ "$use_extracted_codebook" == "True" ]; then
+    if [ ! "$teacher_model_id" == "hubert_xtralarge_ll60k_finetune_ls960" ]; then
+      log "Currently we have only uploaded codebook indexes from the teacher model hubert_xtralarge_ll60k_finetune_ls960"
+      exit 1
+    fi
+    mkdir -p $exp_dir/vq
+    codebook_dir=$exp_dir/vq/$teacher_model_id
+    mkdir -p $codebook_dir
+    codebook_download_dir=$exp_dir/download_codebook
+    if [ -d $codebook_download_dir ]; then
+      log "$codebook_download_dir exists, you should remove it first."
+      exit 1
+    fi
+    log "Downloading extracted codebook indexes to $codebook_download_dir"
+    # Make sure you have git-lfs installed (https://git-lfs.github.com)
+    git lfs install
+    git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
+
+    mkdir -p data/vq_fbank
+    mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
+    mkdir -p $codebook_dir/splits4
+    mv $codebook_download_dir/*.h5 $codebook_dir/splits4/
+    log "Remove $codebook_download_dir"
+    rm -rf $codebook_download_dir
+  fi
+
   ./pruned_transducer_stateless6/extract_codebook_index.py \
-    --full-libri False
+    --full-libri $full_libri \
+    --exp-dir $exp_dir \
+    --embedding-layer 36 \
+    --num-utts 1000 \
+    --num-codebooks 8 \
+    --max-duration 100 \
+    --teacher-model-id $teacher_model_id \
+    --use-extracted-codebook $use_extracted_codebook
 fi

-if [ $stage -eq 3 ]; then
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   # Example training script.
   # Note: it's better to set spec-aug-time-warp-factor=-1
   WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
   ./pruned_transducer_stateless6/train.py \
     --manifest-dir ./data/vq_fbank \
     --master-port 12359 \
-    --full-libri False \
+    --full-libri $full_libri \
     --spec-aug-time-warp-factor -1 \
     --max-duration 300 \
     --world-size ${WORLD_SIZE} \
-    --num-epochs 20
+    --num-epochs 20 \
+    --exp-dir $exp_dir \
+    --enable-distillation True
 fi

-if [ $stage -eq 4 ]; then
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   # Results should be similar to:
   # errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67
   # errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60
@@ -140,5 +202,6 @@ if [ $stage -eq 4 ]; then
     --epoch 20 \
     --avg 10 \
     --max-duration 200 \
-    --exp-dir ./pruned_transducer_stateless6/exp
+    --exp-dir $exp_dir \
+    --enable-distillation True
 fi
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
index 4739a6526..701cad73c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
@@ -128,7 +128,7 @@ def get_parser():
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to load averaged model. Currently it only supports "
         "using --epoch. If True, it would decode with the averaged model "
         "over the epoch range from `epoch-avg` (excluded) to `epoch`."
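
Note on the new default: --use-averaged-model=True decodes with parameters averaged over the epoch range (epoch-avg, epoch], as the help text above describes; averaging the last several checkpoints usually smooths out checkpoint-to-checkpoint noise. For intuition, plain state-dict averaging looks like the sketch below (icefall's own implementation in icefall/checkpoint.py is more involved; the function name here is illustrative):

    import torch


    def average_state_dicts(state_dicts):
        """Element-wise mean of model state_dicts sharing identical keys."""
        avg = {k: v.clone().float() for k, v in state_dicts[0].items()}
        for sd in state_dicts[1:]:
            for k in avg:
                avg[k] += sd[k].float()
        for k in avg:
            avg[k] /= len(state_dicts)
        return avg
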
@@ -143,6 +143,13 @@ def get_parser():
         help="The experiment dir",
     )

+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
     parser.add_argument(
         "--bpe-model",
         type=str,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py
index c5c172ff2..21409287c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py
@@ -24,7 +24,7 @@ import torch
 from vq_utils import CodebookIndexExtractor
 from asr_datamodule import LibriSpeechAsrDataModule
 from hubert_xlarge import HubertXlargeFineTuned
-from icefall.utils import AttributeDict
+from icefall.utils import AttributeDict, str2bool


 def get_parser():
@@ -38,6 +38,13 @@ def get_parser():
         help="The experiment dir",
     )

+    parser.add_argument(
+        "--use-extracted-codebook",
+        type=str2bool,
+        default=False,
+        help="Whether to use the extracted codebook indexes.",
+    )
+
     return parser


@@ -71,9 +78,13 @@ def main():
     params.world_size = world_size

     extractor = CodebookIndexExtractor(params=params)
-    extractor.extract_and_save_embedding()
-    extractor.train_quantizer()
-    extractor.extract_codebook_indexes()
+    if not params.use_extracted_codebook:
+        extractor.extract_and_save_embedding()
+        extractor.train_quantizer()
+        extractor.extract_codebook_indexes()
+
+    extractor.reuse_manifests()
+    extractor.join_manifests()


 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
index 9e9fc1440..b904e1e59 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
@@ -41,7 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --full-libri 1 \
   --max-duration 550

-# For distiallation with codebook_indexes:
+# For distillation with codebook_indexes:

 ./pruned_transducer_stateless6/train.py \
   --manifest-dir ./data/vq_fbank \
@@ -300,6 +300,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
     return parser


@@ -372,7 +379,6 @@ def get_params() -> AttributeDict:
         "model_warm_step": 3000,  # arg given to model, not for lrate
         "env_info": get_env_info(),
         # parameters for distillation with codebook indexes.
-        "enable_distiallation": True,
         "distillation_layer": 5,  # 0-based index
         # Since output rate of hubert is 50, while that of encoder is 8,
         # two successive codebook_index are concatenated together.
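
Note on the comment just above: the HuBERT teacher emits 50 embedding frames per second, more than the student encoder produces, so every two successive frames of codebook indexes are concatenated before the codebook loss is computed. A minimal sketch of that pairing (the helper name is hypothetical; the actual concatenation lives in the pruned_transducer_stateless6 model code):

    import torch


    def pair_codebook_indexes(indexes: torch.Tensor) -> torch.Tensor:
        # indexes: (N, T, C) integer tensor at the teacher's 50 Hz rate,
        # where C == num_codebooks (8 here). Returns (N, T // 2, 2 * C):
        # half the frame rate, twice the codebooks per frame.
        N, T, C = indexes.shape
        if T % 2 != 0:  # drop the trailing frame if T is odd
            indexes = indexes[:, : T - 1, :]
            T -= 1
        return indexes.reshape(N, T // 2, 2 * C)


    # e.g. a batch of 2 utterances, 100 teacher frames, 8 codebooks
    x = torch.randint(0, 256, (2, 100, 8))
    assert pair_codebook_indexes(x).shape == (2, 50, 16)
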
@@ -394,7 +400,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         middle_output_layer=params.distillation_layer
-        if params.enable_distiallation
+        if params.enable_distillation
         else None,
     )
     return encoder
@@ -433,9 +439,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
-        num_codebooks=params.num_codebooks
-        if params.enable_distiallation
-        else 0,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
     )
     return model
@@ -615,7 +619,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     info = MetricsTracker()
-    if is_training and params.enable_distiallation:
+    if is_training and params.enable_distillation:
         codebook_indexes, _ = extract_codebook_indexes(batch)
         codebook_indexes = codebook_indexes.to(device)
     else:
@@ -645,7 +649,7 @@ def compute_loss(
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
-        if is_training and params.enable_distiallation:
+        if is_training and params.enable_distillation:
             assert codebook_loss is not None
             loss += params.codebook_loss_scale * codebook_loss

@@ -661,7 +665,7 @@ def compute_loss(
         info["loss"] = loss.detach().cpu().item()
         info["simple_loss"] = simple_loss.detach().cpu().item()
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
-        if is_training and params.enable_distiallation:
+        if is_training and params.enable_distillation:
             info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
index c4935f921..e3dcd039b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
@@ -37,6 +37,7 @@ from icefall.utils import (
     setup_logger,
 )
 from lhotse import CutSet, load_manifest
+from lhotse.cut import MonoCut
 from lhotse.features.io import NumpyHdf5Writer


@@ -62,16 +63,15 @@ class CodebookIndexExtractor:

         setup_logger(f"{self.vq_dir}/log-vq_extraction")

     def init_dirs(self):
-        # vq_dir is the root dir for quantizer:
-        # training data/ quantizer / extracted codebook indexes
+        # vq_dir is the root dir for quantization, containing:
+        # training data, trained quantizer, and extracted codebook indexes
         self.vq_dir = (
             self.params.exp_dir / f"vq/{self.params.teacher_model_id}/"
         )
         self.vq_dir.mkdir(parents=True, exist_ok=True)

-        # manifest_dir for :
-        #  splited original manifests,
-        #  extracted codebook indexes and their related manifests
+        # manifest_dir contains:
+        # split original manifests, extracted codebook indexes with related manifests  # noqa
         self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}"

         self.manifest_dir.mkdir(parents=True, exist_ok=True)
@@ -135,6 +135,7 @@ class CodebookIndexExtractor:
             logging.warn(warn_message)
             return

+        logging.info("Start to extract embeddings for training the quantizer.")
         total_cuts = 0
         with NumpyHdf5Writer(self.embedding_file_path) as writer:
             for batch_idx, batch in enumerate(self.quantizer_train_dl):
@@ -187,14 +188,15 @@ class CodebookIndexExtractor:
             return

         assert self.embedding_file_path.exists()
+        logging.info("Start to train the quantizer.")
         trainer = quantization.QuantizerTrainer(
             dim=self.params.embedding_dim,
             bytes_per_frame=self.params.num_codebooks,
             device=self.params.device,
         )
         train, valid = quantization.read_hdf5_data(self.embedding_file_path)
-        B = 512  # Minibatch size, this is very arbitrary, it's close to what we used
-        # when we tuned this method.
+        B = 512  # Minibatch size, this is very arbitrary,
+        # it's close to what we used when we tuned this method.

         def minibatch_generator(data: torch.Tensor, repeat: bool):
             assert 3 * B < data.shape[0]
@@ -222,18 +224,50 @@ class CodebookIndexExtractor:
         """
         for subset in self.params.subsets:
             logging.info(f"About to split {subset}.")
-            ori_manifest = f"./data/fbank/cuts_train-{subset}.json.gz"
+            ori_manifest = (
+                f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz"
+            )
             split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}"
             os.system(f"{split_cmd}")

+    def join_manifests(self):
+        """
+        Join the vq manifest to the original manifest according to cut id.
+        """
+        logging.info("Start to join manifest files.")
+        for subset in self.params.subsets:
+            vq_manifest_path = (
+                self.dst_manifest_dir
+                / f"librispeech_cuts_train-{subset}-vq.jsonl.gz"
+            )
+            ori_manifest_path = (
+                self.ori_manifest_dir
+                / f"librispeech_cuts_train-{subset}.jsonl.gz"
+            )
+            dst_vq_manifest_path = (
+                self.dst_manifest_dir
+                / f"librispeech_cuts_train-{subset}.jsonl.gz"
+            )
+            cuts_vq = load_manifest(vq_manifest_path)
+            cuts_ori = load_manifest(ori_manifest_path)
+            cuts_vq = cuts_vq.sort_like(cuts_ori)
+            for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
+                assert cut_vq.id == cut_ori.id
+                cut_ori.codebook_indexes = cut_vq.codebook_indexes
+
+            CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
+            logging.info(f"Processed {subset}.")
+            logging.info(f"Saved to {dst_vq_manifest_path}.")
+
     def merge_vq_manifests(self):
         """
         Merge the generated vq-included manifests and store them in self.dst_manifest_dir.
""" for subset in self.params.subsets: - vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-cuts_train-{subset}*.json.gz" + vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir / f"cuts_train-{subset}.json.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -273,7 +307,6 @@ class CodebookIndexExtractor: os.symlink(ori_manifest_path, dst_manifest_path) def create_vq_fbank(self): - self.reuse_manifests() self.merge_vq_manifests() @cached_property @@ -294,11 +327,13 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = f"./data/fbank/cuts_train-{subset}.json.gz" + ori_manifest_path = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ) else: ori_manifest_path = ( self.manifest_dir - / f"cuts_train-{subset}.{self.params.manifest_index}.json.gz" + / f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz" # noqa ) cuts = load_manifest(ori_manifest_path) @@ -311,6 +346,7 @@ class CodebookIndexExtractor: torch.cuda.empty_cache() def extract_codebook_indexes(self): + logging.info("Start to extract codebook indexes.") if self.params.world_size == 1: self.extract_codebook_indexes_imp() else: @@ -333,7 +369,7 @@ class CodebookIndexExtractor: def extract_codebook_indexes_imp(self): for subset in self.params.subsets: num_cuts = 0 - cuts = [] + new_cuts = [] if self.params.world_size == 1: manifest_file_id = f"{subset}" else: @@ -356,15 +392,23 @@ class CodebookIndexExtractor: assert len(cut_list) == codebook_indexes.shape[0] assert all(c.start == 0 for c in supervisions["cut"]) + new_cut_list = [] for idx, cut in enumerate(cut_list): - cut.codebook_indexes = writer.store_array( + new_cut = MonoCut( + id=cut.id, + start=cut.start, + duration=cut.duration, + channel=cut.channel, + ) + new_cut.codebook_indexes = writer.store_array( key=cut.id, value=codebook_indexes[idx][: num_frames[idx]], frame_shift=0.02, temporal_dim=0, start=0, ) - cuts += cut_list + new_cut_list.append(new_cut) + new_cuts += new_cut_list num_cuts += len(cut_list) message = f"Processed {num_cuts} cuts from {subset}" if self.params.world_size > 1: @@ -373,9 +417,9 @@ class CodebookIndexExtractor: json_file_path = ( self.manifest_dir - / f"with_codebook_indexes-cuts_train-{manifest_file_id}.json.gz" + / f"with_codebook_indexes-librispeech-cuts_train-{manifest_file_id}.jsonl.gz" # noqa ) - CutSet.from_cuts(cuts).to_json(json_file_path) + CutSet.from_cuts(new_cuts).to_jsonl(json_file_path) @torch.no_grad()