diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index d5d3008aa..a38cf590c 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -43,7 +43,7 @@ mkdir -p $exp_dir # full_libri can be "True" or "False" # "True" -> use full librispeech dataset for distillation # "False" -> use train-clean-100 subset for distillation -full_libri=False +full_libri=True # use_extracted_codebook can be "True" or "False" # "True" -> stage 0 and stage 1 would be skipped, @@ -145,8 +145,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960" exit 1 fi + # The codebook indexes to be downloaded are generated using the following setup: + embedding_layer=36 + num_codebooks=8 + mkdir -p $exp_dir/vq - codebook_dir=$exp_dir/vq/$teacher_model_id + codebook_dir=$exp_dir/vq/${teacher_model_id}_layer${embedding_layer}_cb${num_codebooks} mkdir -p codebook_dir codebook_download_dir=$exp_dir/download_codebook if [ -d $codebook_download_dir ]; then @@ -164,8 +168,9 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then git lfs install git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir - mkdir -p data/vq_fbank - mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/ + vq_fbank=data/vq_fbank_layer${embedding_layer}_cb${num_codebooks}/ + mkdir -p $vq_fbank + mv $codebook_download_dir/*.jsonl.gz $vq_fbank mkdir -p $codebook_dir/splits4 mv $codebook_download_dir/*.h5 $codebook_dir/splits4/ log "Remove $codebook_download_dir" @@ -181,6 +186,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --max-duration 100 \ --teacher-model-id $teacher_model_id \ --use-extracted-codebook $use_extracted_codebook + + if [ "$full_libri" == "True" ]; then + # Merge the 3 subsets and create a full one + rm ${vq_fbank}/librispeech_cuts_train-all-shuf.jsonl.gz + cat <(gunzip -c ${vq_fbank}/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c ${vq_fbank}/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c ${vq_fbank}/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > ${vq_fbank}/librispeech_cuts_train-all-shuf.jsonl.gz + fi fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index bf072d865..14ff86f23 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,7 +68,10 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + self.vq_dir = ( + self.params.exp_dir + / f"vq/{self.params.teacher_model_id}_layer{self.params.embedding_layer}_cb{self.params.num_codebooks}/" + ) self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -79,7 +82,10 @@ class CodebookIndexExtractor: # It's doesn't matter whether ori_manifest_dir is str or Path. # Set it to Path to be consistent. self.ori_manifest_dir = Path("./data/fbank/") - self.dst_manifest_dir = Path("./data/vq_fbank/") + self.dst_manifest_dir = Path( + f"./data/vq_fbank_layer" + + f"{self.params.embedding_layer}_cb{self.params.num_codebooks}/" + ) self.dst_manifest_dir.mkdir(parents=True, exist_ok=True) @@ -284,7 +290,10 @@ class CodebookIndexExtractor: Merge generated vq included manfiests and storage to self.dst_manifest_dir. """ for subset in self.params.subsets: - vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" + vq_manifests = ( + f"{self.manifest_dir}/" + + f"with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" + ) dst_vq_manifest = ( self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" )