Mirror of https://github.com/k2-fsa/icefall.git
Upload extracted codebook indexes (#429)
* save only vq-related info to manifest
* support to join manifest files
* support using extracted codebook indexes
* fix doc
* minor fix
* add enable-distillation argument option, fix minor typos
* fix style
* fix typo
This commit is contained in:
parent 91b2765cfd
commit d3daeaf5cd

egs/librispeech/ASR/distillation_with_hubert.sh | 119 | Normal file → Executable file
@@ -1,3 +1,5 @@
#!/usr/bin/env bash
#
# A short introduction about distillation framework.
#
# A typical traditional distillation method is
@@ -14,15 +16,15 @@
# teacher embeddings.
# 3. a middle layer 6(1-based) out of total 6 layers is used to extract
# student embeddings.

# This is an example to do distillation with librispeech clean-100 subset.
# run with command:
# bash distillation_with_hubert.sh [0|1|2|3|4]
#
# For example command
# bash distillation_with_hubert.sh 0
# will download hubert model.
stage=$1
# To directly download the extracted codebook indexes for model distillation, you can
# set stage=2, stop_stage=4, use_extracted_codebook=True
#
# To start from scratch, you can
# set stage=0, stop_stage=4, use_extracted_codebook=False

stage=0
stop_stage=4

# Set the GPUs available.
# This script requires at least one GPU.
@@ -33,10 +35,35 @@ stage=$1
# export CUDA_VISIBLE_DEVICES="0"
#
# Suppose GPU 2,3,4,5 are available.
export CUDA_VISIBLE_DEVICES="2,3,4,5"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

exp_dir=./pruned_transducer_stateless6/exp
mkdir -p $exp_dir

if [ $stage -eq 0 ]; then
# full_libri can be "True" or "False"
# "True" -> use full librispeech dataset for distillation
# "False" -> use train-clean-100 subset for distillation
full_libri=False

# use_extracted_codebook can be "True" or "False"
# "True" -> stage 0 and stage 1 would be skipped,
# and directly download the extracted codebook indexes for distillation
# "False" -> start from scratch
use_extracted_codebook=False

# teacher_model_id can be one of
# "hubert_xtralarge_ll60k_finetune_ls960" -> the fine-tuned model; it is the one we currently use.
# "hubert_xtralarge_ll60k" -> the pretrained model without fine-tuning
teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" == "True" ]; then
  log "Stage 0: Download HuBERT model"
  # Preparation stage.

  # Install fairseq according to:
@@ -45,7 +72,7 @@ if [ $stage -eq 0 ]; then
  # commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
  has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
  if [ $has_fairseq == 'False' ]; then
    echo "Please install fairseq before running following stages"
    log "Please install fairseq before running following stages"
    exit 1
  fi

@@ -56,42 +83,41 @@ if [ $stage -eq 0 ]; then

  has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
  if [ $has_quantization == 'False' ]; then
    echo "Please install quantization before running following stages"
    log "Please install quantization before running following stages"
    exit 1
  fi

  echo "Download hubert model."
  log "Download HuBERT model."
  # Parameters about model.
  exp_dir=./pruned_transducer_stateless6/exp/
  model_id=hubert_xtralarge_ll60k_finetune_ls960
  hubert_model_dir=${exp_dir}/hubert_models
  hubert_model=${hubert_model_dir}/${model_id}.pt
  hubert_model=${hubert_model_dir}/${teacher_model_id}.pt
  mkdir -p ${hubert_model_dir}
  # For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert
  if [ -f ${hubert_model} ]; then
    echo "hubert model alread exists."
    log "HuBERT model already exists."
  else
    wget -c https://dl.fbaipublicfiles.com/hubert/${model_id} -P ${hubert_model}
    wget -c https://dl.fbaipublicfiles.com/hubert/${teacher_model_id}.pt -P ${hubert_model_dir}
    wget -c https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
  fi
fi

if [ ! -d ./data/fbank ]; then
  echo "This script assumes ./data/fbank is already generated by prepare.sh"
  log "This script assumes ./data/fbank is already generated by prepare.sh"
  exit 1
fi

if [ $stage -eq 1 ]; then
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! "$use_extracted_codebook" == "True" ]; then
  log "Stage 1: Verify that the downloaded HuBERT model is correct."
  # This stage is not directly used by codebook indexes extraction.
  # It is a way to "prove" that the downloaded HuBERT model
  # is used correctly for inference: the WERs should look normal.
  # Expected WERs:
  # [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
  # [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
  ./pruned_transducer_stateless6/hubert_decode.py
  ./pruned_transducer_stateless6/hubert_decode.py --exp-dir $exp_dir
fi

if [ $stage -eq 2 ]; then
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  # Analysis of disk usage:
  # With num_codebooks==8, each teacher embedding is quantized into
  # a sequence of eight 8-bit integers, i.e. only eight bytes are needed.
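
A quick back-of-the-envelope check of the disk-usage claim above (editor's illustration, not part of the script; the 50 frames/s HuBERT output rate is the figure quoted in the train.py comments later in this diff):

hours = 100                          # hypothetical amount of training audio
bytes_total = hours * 3600 * 50 * 8  # seconds * frames/s * 8 bytes per frame
print(bytes_total / 1e6, "MB")       # 144.0 MB, before container overhead
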
@@ -113,25 +139,61 @@ if [ $stage -eq 2 ]; then
  # During the extraction of the quantizer's training data (teacher embeddings) and its training,
  # only the first ONE GPU is used.
  # During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used.

if [ "$use_extracted_codebook" == "True" ]; then
|
||||
if [ ! "$teacher_model_id" == "hubert_xtralarge_ll60k_finetune_ls960" ]; then
|
||||
log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960"
|
||||
exit 1
|
||||
fi
|
||||
mkdir -p $exp_dir/vq
|
||||
codebook_dir=$exp_dir/vq/$teacher_model_id
|
||||
mkdir -p codebook_dir
|
||||
codebook_download_dir=$exp_dir/download_codebook
|
||||
if [ -d $codebook_download_dir ]; then
|
||||
log "$codebook_download_dir exists, you should remove it first."
|
||||
exit 1
|
||||
fi
|
||||
log "Downloading extracted codebook indexes to $codebook_download_dir"
|
||||
# Make sure you have git-lfs installed (https://git-lfs.github.com)
|
||||
git lfs install
|
||||
git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
|
||||
|
||||
mkdir -p data/vq_fbank
|
||||
mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
|
||||
mkdir -p $codebook_dir/splits4
|
||||
mv $codebook_download_dir/*.h5 $codebook_dir/splits4/
|
||||
log "Remove $codebook_download_dir"
|
||||
rm -rf $codebook_download_dir
|
||||
fi
|
||||
|
||||
  ./pruned_transducer_stateless6/extract_codebook_index.py \
    --full-libri False
    --full-libri $full_libri \
    --exp-dir $exp_dir \
    --embedding-layer 36 \
    --num-utts 1000 \
    --num-codebooks 8 \
    --max-duration 100 \
    --teacher-model-id $teacher_model_id \
    --use-extracted-codebook $use_extracted_codebook
fi

if [ $stage -eq 3 ]; then
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  # Example training script.
  # Note: it's better to set spec-aug-time-warp-factor=-1
  WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
  ./pruned_transducer_stateless6/train.py \
    --manifest-dir ./data/vq_fbank \
    --master-port 12359 \
    --full-libri False \
    --full-libri $full_libri \
    --spec-aug-time-warp-factor -1 \
    --max-duration 300 \
    --world-size ${WORLD_SIZE} \
    --num-epochs 20
    --num-epochs 20 \
    --exp-dir $exp_dir \
    --enable-distillation True
fi

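The WORLD_SIZE line above simply counts the comma-separated entries in CUDA_VISIBLE_DEVICES so that train.py launches one process per visible GPU. A Python equivalent of that awk one-liner (editor's illustration):

import os

devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
world_size = len(devices.split(",")) if devices else 0
print(world_size)  # 4 when CUDA_VISIBLE_DEVICES="0,1,2,3"
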
if [ $stage -eq 4 ]; then
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  # Results should be similar to:
  # errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67
  # errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60
@@ -140,5 +202,6 @@ if [ $stage -eq 4 ]; then
    --epoch 20 \
    --avg 10 \
    --max-duration 200 \
    --exp-dir ./pruned_transducer_stateless6/exp
    --exp-dir $exp_dir \
    --enable-distillation True
fi

@@ -128,7 +128,7 @@ def get_parser():
    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=False,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
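
The help text above describes decoding with a model averaged over the epoch range (epoch-avg, epoch]. A minimal sketch of plain checkpoint averaging in that spirit, assuming checkpoints saved as epoch-<n>.pt with a "model" state dict as in typical icefall recipes; the actual --use-averaged-model logic in decode.py is more elaborate:

import torch

def average_checkpoints(exp_dir: str, epoch: int, avg: int) -> dict:
    # Average the "model" state dicts of epochs (epoch-avg, epoch], i.e. `avg` checkpoints.
    epochs = range(epoch - avg + 1, epoch + 1)
    state_dicts = [
        torch.load(f"{exp_dir}/epoch-{e}.pt", map_location="cpu")["model"]
        for e in epochs
    ]
    averaged = {}
    for key in state_dicts[0]:
        averaged[key] = sum(sd[key] for sd in state_dicts) / len(state_dicts)
    return averaged
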
@@ -143,6 +143,13 @@ def get_parser():
        help="The experiment dir",
    )

    parser.add_argument(
        "--enable-distillation",
        type=str2bool,
        default=True,
        help="Whether to enable distillation.",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
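
The `type=str2bool` pattern used above is what lets the shell script pass "True"/"False" strings on the command line. A minimal sketch of such a converter, assuming an implementation along the lines of icefall.utils.str2bool (editor's illustration):

import argparse

def str2bool(v) -> bool:
    # Accept common textual spellings of booleans from the command line.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("--enable-distillation", type=str2bool, default=True)
args = parser.parse_args(["--enable-distillation", "False"])
print(args.enable_distillation)  # -> False
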
@@ -24,7 +24,7 @@ import torch
from vq_utils import CodebookIndexExtractor
from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import AttributeDict
from icefall.utils import AttributeDict, str2bool


def get_parser():
@@ -38,6 +38,13 @@ def get_parser():
        help="The experiment dir",
    )

    parser.add_argument(
        "--use-extracted-codebook",
        type=str2bool,
        default=False,
        help="Whether to use the extracted codebook indexes.",
    )

    return parser


@@ -71,9 +78,13 @@ def main():
    params.world_size = world_size

    extractor = CodebookIndexExtractor(params=params)
    extractor.extract_and_save_embedding()
    extractor.train_quantizer()
    extractor.extract_codebook_indexes()
    if not params.use_extracted_codebook:
        extractor.extract_and_save_embedding()
        extractor.train_quantizer()
        extractor.extract_codebook_indexes()

    extractor.reuse_manifests()
    extractor.join_manifests()


if __name__ == "__main__":
@@ -41,7 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  --full-libri 1 \
  --max-duration 550

# For distiallation with codebook_indexes:
# For distillation with codebook_indexes:

./pruned_transducer_stateless6/train.py \
  --manifest-dir ./data/vq_fbank \
@@ -300,6 +300,13 @@ def get_parser():
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--enable-distillation",
        type=str2bool,
        default=True,
        help="Whether to enable distillation.",
    )

    return parser


@@ -372,7 +379,6 @@ def get_params() -> AttributeDict:
            "model_warm_step": 3000,  # arg given to model, not for lrate
            "env_info": get_env_info(),
            # parameters for distillation with codebook indexes.
            "enable_distiallation": True,
            "distillation_layer": 5,  # 0-based index
            # Since output rate of hubert is 50, while that of encoder is 8,
            # two successive codebook_index are concatenated together.
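
The comment above says that two successive codebook-index frames are concatenated so the teacher's higher frame rate lines up with the encoder's lower one. A minimal sketch of that pairing for a (T, num_codebooks) tensor (editor's illustration, not icefall's exact implementation):

import torch

codebook_indexes = torch.randint(0, 256, (10, 8))    # (T=10 teacher frames, 8 codebooks)
T = codebook_indexes.shape[0] // 2 * 2                # drop a trailing odd frame, if any
paired = codebook_indexes[:T].reshape(T // 2, 2 * 8)  # (T/2, 16): two frames side by side
print(paired.shape)                                   # torch.Size([5, 16])
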
@@ -394,7 +400,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
        dim_feedforward=params.dim_feedforward,
        num_encoder_layers=params.num_encoder_layers,
        middle_output_layer=params.distillation_layer
        if params.enable_distiallation
        if params.enable_distillation
        else None,
    )
    return encoder
@@ -433,9 +439,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
        num_codebooks=params.num_codebooks
        if params.enable_distiallation
        else 0,
        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
    )
    return model

@@ -615,7 +619,7 @@ def compute_loss(
    y = k2.RaggedTensor(y).to(device)

    info = MetricsTracker()
    if is_training and params.enable_distiallation:
    if is_training and params.enable_distillation:
        codebook_indexes, _ = extract_codebook_indexes(batch)
        codebook_indexes = codebook_indexes.to(device)
    else:
@@ -645,7 +649,7 @@ def compute_loss(
            params.simple_loss_scale * simple_loss
            + pruned_loss_scale * pruned_loss
        )
        if is_training and params.enable_distiallation:
        if is_training and params.enable_distillation:
            assert codebook_loss is not None
            loss += params.codebook_loss_scale * codebook_loss

@@ -661,7 +665,7 @@ def compute_loss(
    info["loss"] = loss.detach().cpu().item()
    info["simple_loss"] = simple_loss.detach().cpu().item()
    info["pruned_loss"] = pruned_loss.detach().cpu().item()
    if is_training and params.enable_distiallation:
    if is_training and params.enable_distillation:
        info["codebook_loss"] = codebook_loss.detach().cpu().item()

    return loss, info
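
Read together, the three compute_loss hunks above amount to a weighted sum of the transducer losses plus, during training with distillation enabled, the codebook loss. A schematic of the combination using the parameter names that appear above (editor's sketch, not the full icefall function):

def combine_losses(
    simple_loss,
    pruned_loss,
    codebook_loss,
    simple_loss_scale,
    pruned_loss_scale,
    codebook_loss_scale,
    enable_distillation,
    is_training=True,
):
    # Base transducer objective.
    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
    # Distillation term is only added during training.
    if is_training and enable_distillation:
        assert codebook_loss is not None
        loss = loss + codebook_loss_scale * codebook_loss
    return loss
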
@@ -37,6 +37,7 @@ from icefall.utils import (
    setup_logger,
)
from lhotse import CutSet, load_manifest
from lhotse.cut import MonoCut
from lhotse.features.io import NumpyHdf5Writer


@@ -62,16 +63,15 @@ class CodebookIndexExtractor:
        setup_logger(f"{self.vq_dir}/log-vq_extraction")

    def init_dirs(self):
        # vq_dir is the root dir for quantizer:
        # training data/ quantizer / extracted codebook indexes
        # vq_dir is the root dir for quantization, containing:
        # training data, trained quantizer, and extracted codebook indexes
        self.vq_dir = (
            self.params.exp_dir / f"vq/{self.params.teacher_model_id}/"
        )
        self.vq_dir.mkdir(parents=True, exist_ok=True)

        # manifest_dir for :
        # splited original manifests,
        # extracted codebook indexes and their related manifests
        # manifest_dir contains:
        # split original manifests, extracted codebook indexes with related manifests  # noqa
        self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}"
        self.manifest_dir.mkdir(parents=True, exist_ok=True)

@@ -135,6 +135,7 @@ class CodebookIndexExtractor:
            logging.warn(warn_message)
            return

        logging.info("Start to extract embeddings for training the quantizer.")
        total_cuts = 0
        with NumpyHdf5Writer(self.embedding_file_path) as writer:
            for batch_idx, batch in enumerate(self.quantizer_train_dl):
@@ -187,14 +188,15 @@ class CodebookIndexExtractor:
            return

        assert self.embedding_file_path.exists()
        logging.info("Start to train quantizer.")
        trainer = quantization.QuantizerTrainer(
            dim=self.params.embedding_dim,
            bytes_per_frame=self.params.num_codebooks,
            device=self.params.device,
        )
        train, valid = quantization.read_hdf5_data(self.embedding_file_path)
        B = 512  # Minibatch size, this is very arbitrary, it's close to what we used
        # when we tuned this method.
        B = 512  # Minibatch size, this is very arbitrary,
        # it's close to what we used when we tuned this method.

        def minibatch_generator(data: torch.Tensor, repeat: bool):
            assert 3 * B < data.shape[0]
@@ -222,18 +224,50 @@ class CodebookIndexExtractor:
        """
        for subset in self.params.subsets:
            logging.info(f"About to split {subset}.")
            ori_manifest = f"./data/fbank/cuts_train-{subset}.json.gz"
            ori_manifest = (
                f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz"
            )
            split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}"
            os.system(f"{split_cmd}")

    def join_manifests(self):
        """
        Join the vq manifest to the original manifest according to cut id.
        """
        logging.info("Start to join manifest files.")
        for subset in self.params.subsets:
            vq_manifest_path = (
                self.dst_manifest_dir
                / f"librispeech_cuts_train-{subset}-vq.jsonl.gz"
            )
            ori_manifest_path = (
                self.ori_manifest_dir
                / f"librispeech_cuts_train-{subset}.jsonl.gz"
            )
            dst_vq_manifest_path = (
                self.dst_manifest_dir
                / f"librispeech_cuts_train-{subset}.jsonl.gz"
            )
            cuts_vq = load_manifest(vq_manifest_path)
            cuts_ori = load_manifest(ori_manifest_path)
            cuts_vq = cuts_vq.sort_like(cuts_ori)
            for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
                assert cut_vq.id == cut_ori.id
                cut_ori.codebook_indexes = cut_vq.codebook_indexes

            CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
            logging.info(f"Processed {subset}.")
            logging.info(f"Saved to {dst_vq_manifest_path}.")

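After join_manifests() runs, every cut in the destination manifest should carry the codebook_indexes custom field copied from the vq manifest. A small sanity check along those lines (editor's sketch; the path is hypothetical and follows the naming used above, assuming a recent lhotse version):

from lhotse import load_manifest

# hypothetical local path produced by join_manifests / the download stage
cuts = load_manifest("data/vq_fbank/librispeech_cuts_train-clean-100.jsonl.gz")
missing = [cut.id for cut in cuts if "codebook_indexes" not in (cut.custom or {})]
print(f"{len(missing)} cuts are missing codebook_indexes")
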
    def merge_vq_manifests(self):
        """
        Merge the generated manifests with vq info and store them in self.dst_manifest_dir.
        """
        for subset in self.params.subsets:
            vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-cuts_train-{subset}*.json.gz"
            vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz"
            dst_vq_manifest = (
                self.dst_manifest_dir / f"cuts_train-{subset}.json.gz"
                self.dst_manifest_dir
                / f"librispeech_cuts_train-{subset}-vq.jsonl.gz"
            )
            if 1 == self.params.world_size:
                merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}"
@@ -273,7 +307,6 @@ class CodebookIndexExtractor:
            os.symlink(ori_manifest_path, dst_manifest_path)

    def create_vq_fbank(self):
        self.reuse_manifests()
        self.merge_vq_manifests()

    @cached_property
@@ -294,11 +327,13 @@ class CodebookIndexExtractor:

    def load_ori_dl(self, subset):
        if self.params.world_size == 1:
            ori_manifest_path = f"./data/fbank/cuts_train-{subset}.json.gz"
            ori_manifest_path = (
                f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz"
            )
        else:
            ori_manifest_path = (
                self.manifest_dir
                / f"cuts_train-{subset}.{self.params.manifest_index}.json.gz"
                / f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz"  # noqa
            )

        cuts = load_manifest(ori_manifest_path)
@@ -311,6 +346,7 @@ class CodebookIndexExtractor:
        torch.cuda.empty_cache()

    def extract_codebook_indexes(self):
        logging.info("Start to extract codebook indexes.")
        if self.params.world_size == 1:
            self.extract_codebook_indexes_imp()
        else:
@@ -333,7 +369,7 @@ class CodebookIndexExtractor:
    def extract_codebook_indexes_imp(self):
        for subset in self.params.subsets:
            num_cuts = 0
            cuts = []
            new_cuts = []
            if self.params.world_size == 1:
                manifest_file_id = f"{subset}"
            else:
@@ -356,15 +392,23 @@ class CodebookIndexExtractor:
            assert len(cut_list) == codebook_indexes.shape[0]
            assert all(c.start == 0 for c in supervisions["cut"])

            new_cut_list = []
            for idx, cut in enumerate(cut_list):
                cut.codebook_indexes = writer.store_array(
                new_cut = MonoCut(
                    id=cut.id,
                    start=cut.start,
                    duration=cut.duration,
                    channel=cut.channel,
                )
                new_cut.codebook_indexes = writer.store_array(
                    key=cut.id,
                    value=codebook_indexes[idx][: num_frames[idx]],
                    frame_shift=0.02,
                    temporal_dim=0,
                    start=0,
                )
            cuts += cut_list
                new_cut_list.append(new_cut)
            new_cuts += new_cut_list
            num_cuts += len(cut_list)
            message = f"Processed {num_cuts} cuts from {subset}"
            if self.params.world_size > 1:
@@ -373,9 +417,9 @@ class CodebookIndexExtractor:

            json_file_path = (
                self.manifest_dir
                / f"with_codebook_indexes-cuts_train-{manifest_file_id}.json.gz"
                / f"with_codebook_indexes-librispeech-cuts_train-{manifest_file_id}.jsonl.gz"  # noqa
            )
            CutSet.from_cuts(cuts).to_json(json_file_path)
            CutSet.from_cuts(new_cuts).to_jsonl(json_file_path)


    @torch.no_grad()
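
For reference, the arrays written through NumpyHdf5Writer above can later be read back through the cut's custom-field loader. A minimal sketch, assuming a recent lhotse version and a manifest name following the json_file_path pattern above (editor's illustration):

from lhotse import load_manifest

# hypothetical local path written by extract_codebook_indexes_imp()
cuts = load_manifest(
    "with_codebook_indexes-librispeech-cuts_train-clean-100.jsonl.gz"
)
cut = next(iter(cuts))
indexes = cut.load_custom("codebook_indexes")  # array of shape (num_frames, num_codebooks)
print(indexes.shape, indexes.dtype)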