mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fix distillation with HuBERT (#790)
* update vq huggingface url * remove hard lhotse version requirement * resolve ID mismatch * small fixes * Update egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> * update version check Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
This commit is contained in:
parent
a24a1cbfa9
commit
05dfd5e630
@ -35,7 +35,7 @@ stop_stage=4
|
||||
# export CUDA_VISIBLE_DEVICES="0"
|
||||
#
|
||||
# Suppose GPU 2,3,4,5 are available.
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
# export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
exp_dir=./pruned_transducer_stateless6/exp
|
||||
mkdir -p $exp_dir
|
||||
@ -49,7 +49,7 @@ full_libri=False
|
||||
# "True" -> stage 0 and stage 1 would be skipped,
|
||||
# and directly download the extracted codebook indexes for distillation
|
||||
# "False" -> start from scratch
|
||||
use_extracted_codebook=False
|
||||
use_extracted_codebook=True
|
||||
|
||||
# teacher_model_id can be one of
|
||||
# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
|
||||
@ -155,8 +155,14 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
fi
|
||||
log "Downloading extracted codebook indexes to $codebook_download_dir"
|
||||
# Make sure you have git-lfs installed (https://git-lfs.github.com)
|
||||
# The codebook indexes are generated using lhotse 1.11.0, to avoid
|
||||
# potential issues, we recommend you to use lhotse version >= 1.11.0
|
||||
lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))")
|
||||
if [ "$lhotse_version" == "False" ]; then
|
||||
log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch."
|
||||
fi
|
||||
git lfs install
|
||||
git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
|
||||
git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
|
||||
|
||||
mkdir -p data/vq_fbank
|
||||
mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
|
||||
|
@ -244,10 +244,36 @@ class CodebookIndexExtractor:
|
||||
)
|
||||
cuts_vq = load_manifest(vq_manifest_path)
|
||||
cuts_ori = load_manifest(ori_manifest_path)
|
||||
cuts_vq = cuts_vq.sort_like(cuts_ori)
|
||||
for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
|
||||
assert cut_vq.id == cut_ori.id
|
||||
cut_ori.codebook_indexes = cut_vq.codebook_indexes
|
||||
assert len(cuts_vq) == len(cuts_ori), "Cuts should have the same length!"
|
||||
|
||||
if set(cuts_vq.ids) == set(cuts_ori.ids):
|
||||
# IDs match exactly
|
||||
cuts_vq = cuts_vq.sort_like(cuts_ori)
|
||||
for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
|
||||
assert cut_vq.id == cut_ori.id, (cut_vq.id, cut_ori.id)
|
||||
cut_ori.codebook_indexes = cut_vq.codebook_indexes
|
||||
else:
|
||||
# in case of ID mismatch, remap them
|
||||
# get the mapping between audio and cut ID
|
||||
logging
|
||||
ori_id_map = {}
|
||||
for id in cuts_ori.ids:
|
||||
# some text normalization
|
||||
if "sp" in id:
|
||||
clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
|
||||
else:
|
||||
clean_id = "-".join(id.split("-")[:3])
|
||||
ori_id_map[clean_id] = id
|
||||
|
||||
for id in cuts_vq.ids:
|
||||
if "sp" in id:
|
||||
clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
|
||||
else:
|
||||
clean_id = "-".join(id.split("-")[:3])
|
||||
assert clean_id in ori_id_map, clean_id
|
||||
cuts_ori[ori_id_map[clean_id]].codebook_indexes = cuts_vq[
|
||||
id
|
||||
].codebook_indexes
|
||||
|
||||
CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
|
||||
logging.info(f"Processed {subset}.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user