Upload extracted codebook indexes (#429)

* save only vq-related info to manifest

* support to join manifest files

* support using extracted codebook indexes

* fix doc

* minor fix

* add enable-distillation argument option, fix minor typos

* fix style

* fix typo
This commit is contained in:
Zengwei Yao 2022-06-21 19:16:59 +08:00 committed by GitHub
parent 91b2765cfd
commit d3daeaf5cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 189 additions and 60 deletions

119
egs/librispeech/ASR/distillation_with_hubert.sh Normal file → Executable file
View File

@ -1,3 +1,5 @@
#!/usr/bin/env bash
#
# A short introduction about distillation framework. # A short introduction about distillation framework.
# #
# A typical traditional distillation method is # A typical traditional distillation method is
@ -14,15 +16,15 @@
# teacher embeddings. # teacher embeddings.
# 3. a middle layer 6(1-based) out of total 6 layers is used to extract # 3. a middle layer 6(1-based) out of total 6 layers is used to extract
# student embeddings. # student embeddings.
# This is an example to do distillation with librispeech clean-100 subset.
# run with command:
# bash distillation_with_hubert.sh [0|1|2|3|4]
# #
# For example command # To directly download the extracted codebook indexes for model distillation, you can
# bash distillation_with_hubert.sh 0 # set stage=2, stop_stage=4, use_extracted_codebook=True
# will download hubert model. #
stage=$1 # To start from scratch, you can
# set stage=0, stop_stage=4, use_extracted_codebook=False
stage=0
stop_stage=4
# Set the GPUs available. # Set the GPUs available.
# This script requires at least one GPU. # This script requires at least one GPU.
@ -33,10 +35,35 @@ stage=$1
# export CUDA_VISIBLE_DEVICES="0" # export CUDA_VISIBLE_DEVICES="0"
# #
# Suppose GPU 2,3,4,5 are available. # Suppose GPU 2,3,4,5 are available.
export CUDA_VISIBLE_DEVICES="2,3,4,5" export CUDA_VISIBLE_DEVICES="0,1,2,3"
exp_dir=./pruned_transducer_stateless6/exp
mkdir -p $exp_dir
if [ $stage -eq 0 ]; then # full_libri can be "True" or "False"
# "True" -> use full librispeech dataset for distillation
# "False" -> use train-clean-100 subset for distillation
full_libri=False
# use_extracted_codebook can be "True" or "False"
# "True" -> stage 0 and stage 1 would be skipped,
# and directly download the extracted codebook indexes for distillation
# "False" -> start from scratch
use_extracted_codebook=False
# teacher_model_id can be one of
# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
# "hubert_xtralarge_ll60k" -> pretrained model without fine-tuning
teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960
log() {
  # Print a timestamped message tagged with the caller's location
  # (file name, line number, and enclosing function).
  # This function is from espnet.
  local caller_file timestamp
  # Strip the directory part so that only the script's base name is shown.
  caller_file=${BASH_SOURCE[1]##*/}
  timestamp=$(date '+%Y-%m-%d %H:%M:%S')
  echo -e "${timestamp} (${caller_file}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" == "True" ]; then
log "Stage 0: Download HuBERT model"
# Preparation stage. # Preparation stage.
# Install fairseq according to: # Install fairseq according to:
@ -45,7 +72,7 @@ if [ $stage -eq 0 ]; then
# commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used. # commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)") has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
if [ $has_fairseq == 'False' ]; then if [ $has_fairseq == 'False' ]; then
echo "Please install fairseq before running following stages" log "Please install fairseq before running following stages"
exit 1 exit 1
fi fi
@ -56,42 +83,41 @@ if [ $stage -eq 0 ]; then
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)") has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
if [ $has_quantization == 'False' ]; then if [ $has_quantization == 'False' ]; then
echo "Please install quantization before running following stages" log "Please install quantization before running following stages"
exit 1 exit 1
fi fi
echo "Download hubert model." log "Download HuBERT model."
# Parameters about model. # Parameters about model.
exp_dir=./pruned_transducer_stateless6/exp/
model_id=hubert_xtralarge_ll60k_finetune_ls960
hubert_model_dir=${exp_dir}/hubert_models hubert_model_dir=${exp_dir}/hubert_models
hubert_model=${hubert_model_dir}/${model_id}.pt hubert_model=${hubert_model_dir}/${teacher_model_id}.pt
mkdir -p ${hubert_model_dir} mkdir -p ${hubert_model_dir}
# For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert # For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert
if [ -f ${hubert_model} ]; then if [ -f ${hubert_model} ]; then
echo "hubert model alread exists." log "HuBERT model alread exists."
else else
wget -c https://dl.fbaipublicfiles.com/hubert/${model_id} -P ${hubert_model} wget -c https://dl.fbaipublicfiles.com/hubert/${teacher_model_id}.pt -P ${hubert_model_dir}
wget -c wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir} wget -c wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
fi fi
fi fi
if [ ! -d ./data/fbank ]; then if [ ! -d ./data/fbank ]; then
echo "This script assumes ./data/fbank is already generated by prepare.sh" log "This script assumes ./data/fbank is already generated by prepare.sh"
exit 1 exit 1
fi fi
if [ $stage -eq 1 ]; then if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! "$use_extracted_codebook" == "True" ]; then
log "Stage 1: Verify that the downloaded HuBERT model is correct."
# This stage is not directly used by codebook indexes extraction. # This stage is not directly used by codebook indexes extraction.
# It is a method to "prove" that the downloaded hubert model # It is a method to "prove" that the downloaded hubert model
# is used for inference in a correct way if WERs look normal. # is used for inference in a correct way if WERs look normal.
# Expect WERs: # Expect WERs:
# [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ] # [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
# [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ] # [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
./pruned_transducer_stateless6/hubert_decode.py ./pruned_transducer_stateless6/hubert_decode.py --exp-dir $exp_dir
fi fi
if [ $stage -eq 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# Analysis of disk usage: # Analysis of disk usage:
# With num_codebooks==8, each teacher embedding is quantized into # With num_codebooks==8, each teacher embedding is quantized into
# a sequence of eight 8-bit integers, i.e. only eight bytes are needed. # a sequence of eight 8-bit integers, i.e. only eight bytes are needed.
@ -113,25 +139,61 @@ if [ $stage -eq 2 ]; then
# During quantizer's training data(teacher embedding) and it's training, # During quantizer's training data(teacher embedding) and it's training,
# only the first ONE GPU is used. # only the first ONE GPU is used.
# During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used. # During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used.
if [ "$use_extracted_codebook" == "True" ]; then
  # Instead of extracting codebook indexes locally, download the ones that
  # were already extracted and uploaded, then move them to the locations
  # the following steps read from.
  if [ ! "$teacher_model_id" == "hubert_xtralarge_ll60k_finetune_ls960" ]; then
    log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960"
    exit 1
  fi
  mkdir -p $exp_dir/vq
  codebook_dir=$exp_dir/vq/$teacher_model_id
  # Fixed: this used to be `mkdir -p codebook_dir` (missing `$`), which
  # created a stray directory literally named "codebook_dir" in the
  # current working directory instead of the intended one.
  mkdir -p $codebook_dir
  codebook_download_dir=$exp_dir/download_codebook
  # Refuse to reuse an existing download directory: `git clone` below
  # needs a fresh target.
  if [ -d $codebook_download_dir ]; then
    log "$codebook_download_dir exists, you should remove it first."
    exit 1
  fi
  log "Downloading extracted codebook indexes to $codebook_download_dir"
  # Make sure you have git-lfs installed (https://git-lfs.github.com)
  git lfs install
  git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir

  # The manifests (*.jsonl.gz) go to data/vq_fbank.
  mkdir -p data/vq_fbank
  mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
  # The codebook index arrays (*.h5) go to the "splits4" sub-directory.
  # NOTE(review): "splits4" looks tied to running with 4 GPUs
  # (world_size == 4) - confirm before changing CUDA_VISIBLE_DEVICES.
  mkdir -p $codebook_dir/splits4
  mv $codebook_download_dir/*.h5 $codebook_dir/splits4/
  log "Remove $codebook_download_dir"
  rm -rf $codebook_download_dir
fi
./pruned_transducer_stateless6/extract_codebook_index.py \ ./pruned_transducer_stateless6/extract_codebook_index.py \
--full-libri False --full-libri $full_libri \
--exp-dir $exp_dir \
--embedding-layer 36 \
--num-utts 1000 \
--num-codebooks 8 \
--max-duration 100 \
--teacher-model-id $teacher_model_id \
--use-extracted-codebook $use_extracted_codebook
fi fi
if [ $stage -eq 3 ]; then if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# Example training script. # Example training script.
# Note: it's better to set spec-aug-time-warp-factor=-1 # Note: it's better to set spec-aug-time-warp-factor=-1
WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}') WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
./pruned_transducer_stateless6/train.py \ ./pruned_transducer_stateless6/train.py \
--manifest-dir ./data/vq_fbank \ --manifest-dir ./data/vq_fbank \
--master-port 12359 \ --master-port 12359 \
--full-libri False \ --full-libri $full_libri \
--spec-aug-time-warp-factor -1 \ --spec-aug-time-warp-factor -1 \
--max-duration 300 \ --max-duration 300 \
--world-size ${WORLD_SIZE} \ --world-size ${WORLD_SIZE} \
--num-epochs 20 --num-epochs 20 \
--exp-dir $exp_dir \
--enable-distillation True
fi fi
if [ $stage -eq 4 ]; then if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
# Results should be similar to: # Results should be similar to:
# errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67 # errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67
# errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60 # errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60
@ -140,5 +202,6 @@ if [ $stage -eq 4 ]; then
--epoch 20 \ --epoch 20 \
--avg 10 \ --avg 10 \
--max-duration 200 \ --max-duration 200 \
--exp-dir ./pruned_transducer_stateless6/exp --exp-dir $exp_dir \
--enable-distillation True
fi fi

View File

@ -128,7 +128,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--use-averaged-model", "--use-averaged-model",
type=str2bool, type=str2bool,
default=False, default=True,
help="Whether to load averaged model. Currently it only supports " help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model " "using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`." "over the epoch range from `epoch-avg` (excluded) to `epoch`."
@ -143,6 +143,13 @@ def get_parser():
help="The experiment dir", help="The experiment dir",
) )
parser.add_argument(
"--enable-distillation",
type=str2bool,
default=True,
help="Whether to eanble distillation.",
)
parser.add_argument( parser.add_argument(
"--bpe-model", "--bpe-model",
type=str, type=str,

View File

@ -24,7 +24,7 @@ import torch
from vq_utils import CodebookIndexExtractor from vq_utils import CodebookIndexExtractor
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import AttributeDict from icefall.utils import AttributeDict, str2bool
def get_parser(): def get_parser():
@ -38,6 +38,13 @@ def get_parser():
help="The experiment dir", help="The experiment dir",
) )
parser.add_argument(
"--use-extracted-codebook",
type=str2bool,
default=False,
help="Whether to use the extracted codebook indexes.",
)
return parser return parser
@ -71,9 +78,13 @@ def main():
params.world_size = world_size params.world_size = world_size
extractor = CodebookIndexExtractor(params=params) extractor = CodebookIndexExtractor(params=params)
extractor.extract_and_save_embedding() if not params.use_extracted_codebook:
extractor.train_quantizer() extractor.extract_and_save_embedding()
extractor.extract_codebook_indexes() extractor.train_quantizer()
extractor.extract_codebook_indexes()
extractor.reuse_manifests()
extractor.join_manifests()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -41,7 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--full-libri 1 \ --full-libri 1 \
--max-duration 550 --max-duration 550
# For distiallation with codebook_indexes: # For distillation with codebook_indexes:
./pruned_transducer_stateless6/train.py \ ./pruned_transducer_stateless6/train.py \
--manifest-dir ./data/vq_fbank \ --manifest-dir ./data/vq_fbank \
@ -300,6 +300,13 @@ def get_parser():
help="Whether to use half precision training.", help="Whether to use half precision training.",
) )
parser.add_argument(
"--enable-distillation",
type=str2bool,
default=True,
help="Whether to eanble distillation.",
)
return parser return parser
@ -372,7 +379,6 @@ def get_params() -> AttributeDict:
"model_warm_step": 3000, # arg given to model, not for lrate "model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(), "env_info": get_env_info(),
# parameters for distillation with codebook indexes. # parameters for distillation with codebook indexes.
"enable_distiallation": True,
"distillation_layer": 5, # 0-based index "distillation_layer": 5, # 0-based index
# Since output rate of hubert is 50, while that of encoder is 8, # Since output rate of hubert is 50, while that of encoder is 8,
# two successive codebook_index are concatenated together. # two successive codebook_index are concatenated together.
@ -394,7 +400,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
middle_output_layer=params.distillation_layer middle_output_layer=params.distillation_layer
if params.enable_distiallation if params.enable_distillation
else None, else None,
) )
return encoder return encoder
@ -433,9 +439,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
decoder_dim=params.decoder_dim, decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim, joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
num_codebooks=params.num_codebooks num_codebooks=params.num_codebooks if params.enable_distillation else 0,
if params.enable_distiallation
else 0,
) )
return model return model
@ -615,7 +619,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
info = MetricsTracker() info = MetricsTracker()
if is_training and params.enable_distiallation: if is_training and params.enable_distillation:
codebook_indexes, _ = extract_codebook_indexes(batch) codebook_indexes, _ = extract_codebook_indexes(batch)
codebook_indexes = codebook_indexes.to(device) codebook_indexes = codebook_indexes.to(device)
else: else:
@ -645,7 +649,7 @@ def compute_loss(
params.simple_loss_scale * simple_loss params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss + pruned_loss_scale * pruned_loss
) )
if is_training and params.enable_distiallation: if is_training and params.enable_distillation:
assert codebook_loss is not None assert codebook_loss is not None
loss += params.codebook_loss_scale * codebook_loss loss += params.codebook_loss_scale * codebook_loss
@ -661,7 +665,7 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item()
if is_training and params.enable_distiallation: if is_training and params.enable_distillation:
info["codebook_loss"] = codebook_loss.detach().cpu().item() info["codebook_loss"] = codebook_loss.detach().cpu().item()
return loss, info return loss, info

View File

@ -37,6 +37,7 @@ from icefall.utils import (
setup_logger, setup_logger,
) )
from lhotse import CutSet, load_manifest from lhotse import CutSet, load_manifest
from lhotse.cut import MonoCut
from lhotse.features.io import NumpyHdf5Writer from lhotse.features.io import NumpyHdf5Writer
@ -62,16 +63,15 @@ class CodebookIndexExtractor:
setup_logger(f"{self.vq_dir}/log-vq_extraction") setup_logger(f"{self.vq_dir}/log-vq_extraction")
def init_dirs(self): def init_dirs(self):
# vq_dir is the root dir for quantizer: # vq_dir is the root dir for quantization, containing:
# training data/ quantizer / extracted codebook indexes # training data, trained quantizer, and extracted codebook indexes
self.vq_dir = ( self.vq_dir = (
self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" self.params.exp_dir / f"vq/{self.params.teacher_model_id}/"
) )
self.vq_dir.mkdir(parents=True, exist_ok=True) self.vq_dir.mkdir(parents=True, exist_ok=True)
# manifest_dir for : # manifest_dir contains:
# splited original manifests, # splited original manifests, extracted codebook indexes with related manifests # noqa
# extracted codebook indexes and their related manifests
self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}" self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}"
self.manifest_dir.mkdir(parents=True, exist_ok=True) self.manifest_dir.mkdir(parents=True, exist_ok=True)
@ -135,6 +135,7 @@ class CodebookIndexExtractor:
logging.warn(warn_message) logging.warn(warn_message)
return return
logging.info("Start to extract embeddings for training the quantizer.")
total_cuts = 0 total_cuts = 0
with NumpyHdf5Writer(self.embedding_file_path) as writer: with NumpyHdf5Writer(self.embedding_file_path) as writer:
for batch_idx, batch in enumerate(self.quantizer_train_dl): for batch_idx, batch in enumerate(self.quantizer_train_dl):
@ -187,14 +188,15 @@ class CodebookIndexExtractor:
return return
assert self.embedding_file_path.exists() assert self.embedding_file_path.exists()
logging.info("Start to train quantizer.")
trainer = quantization.QuantizerTrainer( trainer = quantization.QuantizerTrainer(
dim=self.params.embedding_dim, dim=self.params.embedding_dim,
bytes_per_frame=self.params.num_codebooks, bytes_per_frame=self.params.num_codebooks,
device=self.params.device, device=self.params.device,
) )
train, valid = quantization.read_hdf5_data(self.embedding_file_path) train, valid = quantization.read_hdf5_data(self.embedding_file_path)
B = 512 # Minibatch size, this is very arbitrary, it's close to what we used B = 512 # Minibatch size, this is very arbitrary,
# when we tuned this method. # it's close to what we used when we tuned this method.
def minibatch_generator(data: torch.Tensor, repeat: bool): def minibatch_generator(data: torch.Tensor, repeat: bool):
assert 3 * B < data.shape[0] assert 3 * B < data.shape[0]
@ -222,18 +224,50 @@ class CodebookIndexExtractor:
""" """
for subset in self.params.subsets: for subset in self.params.subsets:
logging.info(f"About to split {subset}.") logging.info(f"About to split {subset}.")
ori_manifest = f"./data/fbank/cuts_train-{subset}.json.gz" ori_manifest = (
f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz"
)
split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}"
os.system(f"{split_cmd}") os.system(f"{split_cmd}")
def join_manifests(self):
    """
    Join the vq manifest to the original manifest according to cut id.

    For every subset in ``self.params.subsets``:

      1. load the vq manifest ``librispeech_cuts_train-{subset}-vq.jsonl.gz``
         from ``self.dst_manifest_dir`` and the original manifest
         ``librispeech_cuts_train-{subset}.jsonl.gz`` from
         ``self.ori_manifest_dir``;
      2. reorder the vq cuts like the original cuts and copy the
         ``codebook_indexes`` attribute of each vq cut onto the original
         cut with the same id;
      3. save the updated original cuts to ``self.dst_manifest_dir`` as
         ``librispeech_cuts_train-{subset}.jsonl.gz``.

    Raises:
      AssertionError: if, after reordering, a vq cut and the original cut
        at the same position do not have the same id.
    """
    logging.info("Start to join manifest files.")
    for subset in self.params.subsets:
        # All three manifests share this stem; only the directory and the
        # "-vq" suffix differ.
        manifest_stem = f"librispeech_cuts_train-{subset}"
        vq_manifest_path = (
            self.dst_manifest_dir / f"{manifest_stem}-vq.jsonl.gz"
        )
        ori_manifest_path = (
            self.ori_manifest_dir / f"{manifest_stem}.jsonl.gz"
        )
        dst_vq_manifest_path = (
            self.dst_manifest_dir / f"{manifest_stem}.jsonl.gz"
        )

        cuts_vq = load_manifest(vq_manifest_path)
        cuts_ori = load_manifest(ori_manifest_path)
        # Put the vq cuts in the same order as the original cuts so that
        # the two sets can be walked in lockstep below.
        cuts_vq = cuts_vq.sort_like(cuts_ori)
        for cut_vq, cut_ori in zip(cuts_vq, cuts_ori):
            assert cut_vq.id == cut_ori.id, (
                f"Cut id mismatch while joining {subset}: "
                f"{cut_vq.id} vs {cut_ori.id}"
            )
            # The original cut is updated in place; it is what gets saved.
            cut_ori.codebook_indexes = cut_vq.codebook_indexes

        CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
        logging.info(f"Processed {subset}.")
        logging.info(f"Saved to {dst_vq_manifest_path}.")
def merge_vq_manifests(self): def merge_vq_manifests(self):
""" """
Merge the generated manifests that include vq info and store them in self.dst_manifest_dir. Merge the generated manifests that include vq info and store them in self.dst_manifest_dir.
""" """
for subset in self.params.subsets: for subset in self.params.subsets:
vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-cuts_train-{subset}*.json.gz" vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz"
dst_vq_manifest = ( dst_vq_manifest = (
self.dst_manifest_dir / f"cuts_train-{subset}.json.gz" self.dst_manifest_dir
/ f"librispeech_cuts_train-{subset}-vq.jsonl.gz"
) )
if 1 == self.params.world_size: if 1 == self.params.world_size:
merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}"
@ -273,7 +307,6 @@ class CodebookIndexExtractor:
os.symlink(ori_manifest_path, dst_manifest_path) os.symlink(ori_manifest_path, dst_manifest_path)
def create_vq_fbank(self): def create_vq_fbank(self):
self.reuse_manifests()
self.merge_vq_manifests() self.merge_vq_manifests()
@cached_property @cached_property
@ -294,11 +327,13 @@ class CodebookIndexExtractor:
def load_ori_dl(self, subset): def load_ori_dl(self, subset):
if self.params.world_size == 1: if self.params.world_size == 1:
ori_manifest_path = f"./data/fbank/cuts_train-{subset}.json.gz" ori_manifest_path = (
f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz"
)
else: else:
ori_manifest_path = ( ori_manifest_path = (
self.manifest_dir self.manifest_dir
/ f"cuts_train-{subset}.{self.params.manifest_index}.json.gz" / f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz" # noqa
) )
cuts = load_manifest(ori_manifest_path) cuts = load_manifest(ori_manifest_path)
@ -311,6 +346,7 @@ class CodebookIndexExtractor:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def extract_codebook_indexes(self): def extract_codebook_indexes(self):
logging.info("Start to extract codebook indexes.")
if self.params.world_size == 1: if self.params.world_size == 1:
self.extract_codebook_indexes_imp() self.extract_codebook_indexes_imp()
else: else:
@ -333,7 +369,7 @@ class CodebookIndexExtractor:
def extract_codebook_indexes_imp(self): def extract_codebook_indexes_imp(self):
for subset in self.params.subsets: for subset in self.params.subsets:
num_cuts = 0 num_cuts = 0
cuts = [] new_cuts = []
if self.params.world_size == 1: if self.params.world_size == 1:
manifest_file_id = f"{subset}" manifest_file_id = f"{subset}"
else: else:
@ -356,15 +392,23 @@ class CodebookIndexExtractor:
assert len(cut_list) == codebook_indexes.shape[0] assert len(cut_list) == codebook_indexes.shape[0]
assert all(c.start == 0 for c in supervisions["cut"]) assert all(c.start == 0 for c in supervisions["cut"])
new_cut_list = []
for idx, cut in enumerate(cut_list): for idx, cut in enumerate(cut_list):
cut.codebook_indexes = writer.store_array( new_cut = MonoCut(
id=cut.id,
start=cut.start,
duration=cut.duration,
channel=cut.channel,
)
new_cut.codebook_indexes = writer.store_array(
key=cut.id, key=cut.id,
value=codebook_indexes[idx][: num_frames[idx]], value=codebook_indexes[idx][: num_frames[idx]],
frame_shift=0.02, frame_shift=0.02,
temporal_dim=0, temporal_dim=0,
start=0, start=0,
) )
cuts += cut_list new_cut_list.append(new_cut)
new_cuts += new_cut_list
num_cuts += len(cut_list) num_cuts += len(cut_list)
message = f"Processed {num_cuts} cuts from {subset}" message = f"Processed {num_cuts} cuts from {subset}"
if self.params.world_size > 1: if self.params.world_size > 1:
@ -373,9 +417,9 @@ class CodebookIndexExtractor:
json_file_path = ( json_file_path = (
self.manifest_dir self.manifest_dir
/ f"with_codebook_indexes-cuts_train-{manifest_file_id}.json.gz" / f"with_codebook_indexes-librispeech-cuts_train-{manifest_file_id}.jsonl.gz" # noqa
) )
CutSet.from_cuts(cuts).to_json(json_file_path) CutSet.from_cuts(new_cuts).to_jsonl(json_file_path)
@torch.no_grad() @torch.no_grad()