mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fix distillation with HuBERT (#790)
* update vq huggingface url
* remove hard lhotse version requirement
* resolve ID mismatch
* small fixes
* Update egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
  Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
* update version check

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
parent a24a1cbfa9
commit 05dfd5e630
@@ -35,7 +35,7 @@ stop_stage=4
 # export CUDA_VISIBLE_DEVICES="0"
 #
 # Suppose GPU 2,3,4,5 are available.
-export CUDA_VISIBLE_DEVICES="0,1,2,3"
+# export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 exp_dir=./pruned_transducer_stateless6/exp
 mkdir -p $exp_dir
@@ -49,7 +49,7 @@ full_libri=False
 # "True" -> stage 0 and stage 1 would be skipped,
 # and directly download the extracted codebook indexes for distillation
 # "False" -> start from scratch
-use_extracted_codebook=False
+use_extracted_codebook=True
 
 # teacher_model_id can be one of
 # "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
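With `use_extracted_codebook=True`, stages 0 and 1 are skipped and the pre-extracted codebook-index manifests are downloaded instead (the `*.jsonl.gz` files moved into `data/vq_fbank/` in the next hunk). A minimal sketch of inspecting one of the downloaded manifests with lhotse; the manifest file name below is hypothetical:

```python
from lhotse import load_manifest

# Hypothetical manifest name; use whichever *.jsonl.gz file ends up in data/vq_fbank/.
cuts = load_manifest("data/vq_fbank/librispeech_cuts_train-clean-100.jsonl.gz")
cut = next(iter(cuts))
print(cut.id)
# vq_utils.py expects the teacher's codebook indexes as a custom field on each cut.
print(list(cut.custom.keys()) if cut.custom else None)
```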
@@ -155,8 +155,14 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   fi
   log "Downloading extracted codebook indexes to $codebook_download_dir"
   # Make sure you have git-lfs installed (https://git-lfs.github.com)
+  # The codebook indexes were generated with lhotse 1.11.0. To avoid
+  # potential issues, we recommend using lhotse >= 1.11.0.
+  lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))")
+  if [ "$lhotse_version" == "False" ]; then
+    log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch."
+  fi
   git lfs install
-  git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
+  git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
 
   mkdir -p data/vq_fbank
   mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
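The version check added above is packed into a single shell line; unpacked into plain Python it is roughly the following (a sketch, assuming the `packaging` module is importable, as the one-liner itself does):

```python
import lhotse
from packaging import version

# Prints True when the installed lhotse is at least 1.11.0, the version used to
# generate the downloaded codebook-index manifests; the script logs a warning
# when this prints False.
print(version.parse(lhotse.version.__version__) >= version.parse("1.11.0"))
```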
@@ -244,10 +244,35 @@ class CodebookIndexExtractor:
         )
         cuts_vq = load_manifest(vq_manifest_path)
         cuts_ori = load_manifest(ori_manifest_path)
-        cuts_vq = cuts_vq.sort_like(cuts_ori)
-        for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
-            assert cut_vq.id == cut_ori.id
-            cut_ori.codebook_indexes = cut_vq.codebook_indexes
+        assert len(cuts_vq) == len(cuts_ori), "Cuts should have the same length!"
+
+        if set(cuts_vq.ids) == set(cuts_ori.ids):
+            # IDs match exactly
+            cuts_vq = cuts_vq.sort_like(cuts_ori)
+            for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
+                assert cut_vq.id == cut_ori.id, (cut_vq.id, cut_ori.id)
+                cut_ori.codebook_indexes = cut_vq.codebook_indexes
+        else:
+            # In case of ID mismatch, remap them:
+            # first get the mapping between the normalized ID and the original cut ID.
+            ori_id_map = {}
+            for id in cuts_ori.ids:
+                # some ID normalization
+                if "sp" in id:
+                    clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
+                else:
+                    clean_id = "-".join(id.split("-")[:3])
+                ori_id_map[clean_id] = id
+
+            for id in cuts_vq.ids:
+                if "sp" in id:
+                    clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
+                else:
+                    clean_id = "-".join(id.split("-")[:3])
+                assert clean_id in ori_id_map, clean_id
+                cuts_ori[ori_id_map[clean_id]].codebook_indexes = cuts_vq[
+                    id
+                ].codebook_indexes
 
         CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
         logging.info(f"Processed {subset}.")
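The else branch above remaps codebook indexes when the two manifests use different cut IDs that share a common core. A standalone sketch of that ID normalization; the example IDs below are hypothetical and only illustrate the mapping:

```python
def clean(cut_id: str) -> str:
    # Keep the first three dash-separated fields of the cut ID; if the cut is
    # speed-perturbed ("sp" appears in the ID), also keep the trailing "_sp..." part.
    if "sp" in cut_id:
        return "-".join(cut_id.split("-")[:3]) + "_" + cut_id.split("_")[-1]
    return "-".join(cut_id.split("-")[:3])


print(clean("1089-134686-0001-123"))        # -> 1089-134686-0001
print(clean("1089-134686-0001-123_sp1.1"))  # -> 1089-134686-0001_sp1.1
```

Both manifests are normalized with the same rule, so two cuts whose raw IDs differ only in such suffixes map to the same key and their codebook indexes can be matched.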