Fix distillation with HuBERT (#790)

* update vq huggingface url

* remove hard lhotse version requirement

* resolve ID mismatch

* small fixes


* Update egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* update version check

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
This commit is contained in:
marcoyang1998 2022-12-27 15:26:11 +08:00 committed by GitHub
parent a24a1cbfa9
commit 05dfd5e630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 7 deletions

View File

@ -35,7 +35,7 @@ stop_stage=4
# export CUDA_VISIBLE_DEVICES="0"
#
# Suppose GPU 2,3,4,5 are available.
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# export CUDA_VISIBLE_DEVICES="0,1,2,3"
exp_dir=./pruned_transducer_stateless6/exp
mkdir -p $exp_dir
@ -49,7 +49,7 @@ full_libri=False
# "True" -> stage 0 and stage 1 would be skipped,
# and directly download the extracted codebook indexes for distillation
# "False" -> start from scratch
use_extracted_codebook=False
use_extracted_codebook=True
# teacher_model_id can be one of
# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
@ -155,8 +155,14 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
log "Downloading extracted codebook indexes to $codebook_download_dir"
# Make sure you have git-lfs installed (https://git-lfs.github.com)
# The codebook indexes are generated using lhotse 1.11.0, to avoid
# potential issues, we recommend you to use lhotse version >= 1.11.0
lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))")
if [ "$lhotse_version" == "False" ]; then
log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch."
fi
git lfs install
git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
mkdir -p data/vq_fbank
mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/

View File

@ -244,10 +244,36 @@ class CodebookIndexExtractor:
)
cuts_vq = load_manifest(vq_manifest_path)
cuts_ori = load_manifest(ori_manifest_path)
cuts_vq = cuts_vq.sort_like(cuts_ori)
for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
assert cut_vq.id == cut_ori.id
cut_ori.codebook_indexes = cut_vq.codebook_indexes
assert len(cuts_vq) == len(cuts_ori), "Cuts should have the same length!"
if set(cuts_vq.ids) == set(cuts_ori.ids):
# IDs match exactly
cuts_vq = cuts_vq.sort_like(cuts_ori)
for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
assert cut_vq.id == cut_ori.id, (cut_vq.id, cut_ori.id)
cut_ori.codebook_indexes = cut_vq.codebook_indexes
else:
# in case of ID mismatch, remap them
# get the mapping between audio and cut ID
logging
ori_id_map = {}
for id in cuts_ori.ids:
# some text normalization
if "sp" in id:
clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
else:
clean_id = "-".join(id.split("-")[:3])
ori_id_map[clean_id] = id
for id in cuts_vq.ids:
if "sp" in id:
clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
else:
clean_id = "-".join(id.split("-")[:3])
assert clean_id in ori_id_map, clean_id
cuts_ori[ori_id_map[clean_id]].codebook_indexes = cuts_vq[
id
].codebook_indexes
CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
logging.info(f"Processed {subset}.")