distillation with hubert

This commit is contained in:
Guo Liyong 2022-05-26 22:00:40 +08:00
parent 1f37eb5d0c
commit d8f68abff8
12 changed files with 1170 additions and 36 deletions

View File

@ -0,0 +1,95 @@
stage=$1
export CUDA_VISIBLE_DEVICES="2,3,4,5"
if [ $stage -eq 0 ]; then
# Preparation stage.
# Install fairseq according to:
# https://github.com/pytorch/fairseq
# when testing this code:
# commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
#
# Install quantization toolkit:
# pip install git+https://github.com/danpovey/quantization.git@master
# when testing this code:
# commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used.
echo "Download hubert model."
# Parameters about model.
exp_dir=./pruned_transducer_stateless6/exp/
model_id=hubert_xtralarge_ll60k_finetune_ls960
hubert_model_dir=${exp_dir}/hubert_models
hubert_model=${hubert_model_dir}/${model_id}.pt
mkdir -p ${hubert_model_dir}
# For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert
wget -c https://dl.fbaipublicfiles.com/hubert/${model_id}.pt -P ${hubert_model_dir}
wget -c https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir}
fi
if [ ! -d ./data/fbank ]; then
echo "This script assumes ./data/fbank is already generated by prepare.sh"
exit 0
fi
if [ $stage -eq 1 ]; then
# This stage is not directly used by codebook indexes extraction.
# It is a way to verify that the downloaded hubert model
# is run correctly: its WERs should match the expected values below.
# Expect WERs:
# [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ]
# [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ]
./pruned_transducer_stateless6/hubert_decode.py
fi
if [ $stage -eq 2 ]; then
# Analysis of disk usage:
# With num_codebooks==8, each teacher embedding is quantized into
# a sequence of eight 8-bit integers, i.e. only eight bytes are needed.
# The training dataset, i.e. clean-100 with speed perturbs 0.9 and 1.1, has 300 hours.
# The output frame rate of hubert is 50 per second.
# Theoretically, 412M = 300 * 3600 * 50 * 8 / 1024 / 1024 is needed.
# The actual size of all "*.h5" files storing codebook indexes is 450M.
# The extra ~38M is presumably metadata.
# Time consumption analysis:
# For quantizer training data (teacher embeddings) extraction, only 1000 utts from clean-100 are used.
# Together with quantizer training, this takes no more than 20 minutes.
#
# For codebook index extraction,
# with two NVIDIA A100 GPUs, around three hours are needed to process 300 hours of training data,
# i.e. clean-100 with speed perturbs 0.9 and 1.1.
# GPU usage:
# During extraction of the quantizer's training data (teacher embeddings) and quantizer training,
# only the FIRST GPU is used.
# During codebook index extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used.
./pruned_transducer_stateless6/extract_codebook_index.py \
--full-libri False
fi
if [ $stage -eq 3 ]; then
# Example training script.
# Note: it's better to set spec-aug-time-warp-factor=-1
WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
./pruned_transducer_stateless6/train.py \
--manifest-dir ./data/vq_fbank \
--master-port 12359 \
--full-libri False \
--spec-aug-time-warp-factor -1 \
--max-duration 300 \
--world-size ${WORLD_SIZE} \
--num-epochs 20
fi
if [ $stage -eq 4 ]; then
# Results should be similar to:
# errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67
# errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60
./pruned_transducer_stateless6/decode.py \
--decoding-method "modified_beam_search" \
--epoch 20 \
--avg 10 \
--max-duration 200 \
--exp-dir ./pruned_transducer_stateless6/exp
fi

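To double-check the disk-usage estimate in the stage 2 comments above, here is a minimal arithmetic sketch (plain Python, no icefall dependencies; the numbers simply mirror the comments in the script):

hours = 300             # clean-100 plus speed perturbs 0.9 and 1.1
frames_per_second = 50  # hubert output frame rate
bytes_per_frame = 8     # num_codebooks == 8, one byte per codebook index

total_bytes = hours * 3600 * frames_per_second * bytes_per_frame
print(f"{total_bytes / 1024 / 1024:.0f} MB")  # -> 412 MB, vs ~450 MB observed on disk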
View File

@ -18,7 +18,7 @@
import copy
import math
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import torch
from encoder_interface import EncoderInterface
@ -61,6 +61,7 @@ class Conformer(EncoderInterface):
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
middle_output_layer: int = None, # 0-based layer index
) -> None:
super(Conformer, self).__init__()
@ -86,11 +87,25 @@ class Conformer(EncoderInterface):
layer_dropout,
cnn_module_kernel,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
output_layers = []
if middle_output_layer is not None:
assert (
middle_output_layer >= 0
and middle_output_layer < num_encoder_layers
)
output_layers.append(middle_output_layer)
# The last layer is always needed.
output_layers.append(num_encoder_layers - 1)
self.encoder = ConformerEncoder(
encoder_layer, num_encoder_layers, output_layers=output_layers
)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[List[torch.Tensor], torch.Tensor]:
"""
Args:
x:
@ -122,13 +137,11 @@ class Conformer(EncoderInterface):
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(
layer_results = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # a list of tensors, each of shape (N, T, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
return layer_results, lengths
class ConformerEncoderLayer(nn.Module):
@ -279,12 +292,18 @@ class ConformerEncoder(nn.Module):
>>> out = conformer_encoder(src, pos_emb)
"""
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
def __init__(
self,
encoder_layer: nn.Module,
num_layers: int,
output_layers: List[int],
) -> None:
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
self.output_layers = output_layers
def forward(
self,
@ -293,7 +312,7 @@ class ConformerEncoder(nn.Module):
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
) -> List[Tensor]:
r"""Pass the input through the encoder layers in turn.
Args:
@ -312,6 +331,7 @@ class ConformerEncoder(nn.Module):
"""
output = src
layer_results = []
for i, mod in enumerate(self.layers):
output = mod(
output,
@ -320,8 +340,11 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
if i in self.output_layers:
# (T, N, C) --> (N, T, C)
layer_results.append(output.permute(1, 0, 2))
return output
return layer_results
class RelPositionalEncoding(torch.nn.Module):

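The conformer.py change above makes the encoder return a list of outputs from selected layers (the distillation layer plus the last layer) rather than a single tensor. Below is a standalone sketch of the same pattern with a toy module; the class and sizes are made up for illustration and are not the icefall Conformer:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Toy stack of layers that returns the outputs of selected layers."""

    def __init__(self, num_layers: int, dim: int, output_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(dim, dim) for _ in range(num_layers)
        )
        self.output_layers = set(output_layers)

    def forward(self, x):
        layer_results = []
        for i, layer in enumerate(self.layers):
            x = torch.relu(layer(x))
            if i in self.output_layers:
                layer_results.append(x)
        return layer_results

enc = TinyEncoder(num_layers=12, dim=16, output_layers=[5, 11])
outs = enc(torch.randn(4, 16))
print(len(outs), outs[0].shape)  # 2 torch.Size([4, 16])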
View File

@ -19,36 +19,36 @@
"""
Usage:
(1) greedy search
./pruned_transducer_stateless4/decode.py \
./pruned_transducer_stateless6/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--exp-dir ./pruned_transducer_stateless6/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless4/decode.py \
./pruned_transducer_stateless6/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--exp-dir ./pruned_transducer_stateless6/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless4/decode.py \
./pruned_transducer_stateless6/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--exp-dir ./pruned_transducer_stateless6/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless4/decode.py \
./pruned_transducer_stateless6/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--exp-dir ./pruned_transducer_stateless6/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
@ -139,7 +139,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless4/exp",
default="pruned_transducer_stateless6/exp",
help="The experiment dir",
)
@ -260,9 +260,10 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
layer_results, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
encoder_out = layer_results[-1]
hyps = []
if params.decoding_method == "fast_beam_search":

View File

@ -0,0 +1,80 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import torch
from vq_utils import CodebookIndexExtractor
from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--exp-dir",
type=Path,
default="pruned_transducer_stateless6/exp/",
help="The experiment dir",
)
return parser
def get_world_size():
warn_message = (
"It's better to use GPU to extrac codebook indices"
"Please set with commonds like: export CUDA_VISIBLE_DEVICES=0,1,2,3"
)
assert (
torch.cuda.is_available() and "CUDA_VISIBLE_DEVICES" in os.environ
), warn_message
world_size = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
assert world_size > 0, warn_message
return world_size
def main():
world_size = get_world_size()
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
HubertXlargeFineTuned.add_arguments(parser)
CodebookIndexExtractor.add_arguments(parser)
args = parser.parse_args()
params = AttributeDict()
params.update(vars(args))
# reset some parameters needed by hubert.
params.update(HubertXlargeFineTuned.get_params())
params.device = torch.device("cuda", 0)
params.world_size = world_size
extractor = CodebookIndexExtractor(params=params)
extractor.extract_and_save_embedding()
extractor.train_quantizer()
extractor.extract_codebook_indexes()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,205 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--exp-dir",
type=Path,
default="pruned_transducer_stateless6/exp/",
help="The experiment dir",
)
return parser
def decode_dataset(
dl: torch.utils.data.DataLoader,
hubert_model: HubertXlargeFineTuned,
params: AttributeDict,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
hubert_model:
The hubert teacher model used for decoding.
params:
Decoding-related parameters.
Returns:
Return a dict whose key is the decoding method, e.g. "ctc_greedy_search".
Its value is a list of tuples. Each tuple contains two elements:
the first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
hyps = hubert_model.ctc_greedy_search(batch)
texts = batch["supervisions"]["text"]
assert len(hyps) == len(texts)
this_batch = []
for hyp_text, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
hyp_words = hyp_text.split()
this_batch.append((ref_words, hyp_words))
results["ctc_greedy_search"].extend(this_batch)
num_cuts += len(texts)
if batch_idx % 20 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
HubertXlargeFineTuned.add_arguments(parser)
args = parser.parse_args()
params = AttributeDict()
params.update(vars(args))
# reset some parameters needed by hubert.
params.update(HubertXlargeFineTuned.get_params())
params.res_dir = (
params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}"
)
setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
params.device = device
hubert_model = HubertXlargeFineTuned(params)
librispeech = LibriSpeechAsrDataModule(params)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
hubert_model=hubert_model,
params=params,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,204 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import torch
from fairseq import (
checkpoint_utils,
tasks,
utils,
)
from fairseq.data.data_utils import post_process
from omegaconf import OmegaConf
from icefall.utils import AttributeDict
def _load_hubert_model(params: AttributeDict):
cfg_task = OmegaConf.create(
{
"_name": "hubert_pretraining",
"single_target": True,
"fine_tuning": True,
"data": str(params.hubert_model_dir),
}
)
model_path = Path(params.hubert_model_dir) / (
params.teacher_model_id + ".pt"
)
task = tasks.setup_task(cfg_task)
processor = task.target_dictionary
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(str(model_path), separator="\\"),
arg_overrides={},
strict=True,
suffix="",
num_shards=1,
)
model = models[0]
model.to(params.device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
return model, processor
class HubertXlargeFineTuned:
"""
A wrapper of the hubert extra-large fine-tuned model.
It acts as the teacher model and is responsible for:
1. loading the teacher model;
2. extracting embeddings to train the quantizer;
3. extracting codebook indexes;
4. verifying its performance with the ctc_greedy_search method.
"""
def __init__(self, params: AttributeDict):
self.model, self.processor = _load_hubert_model(params)
self.w2v_model = self.model.w2v_encoder.w2v_model
self.params = params
@staticmethod
def get_params() -> AttributeDict:
"""Return a dict containing parameters defined in other modules.
Their default values conflict with hubert's requirements, so they are reset as follows.
"""
params = AttributeDict(
{
# parameters defined in asr_datamodule.py
"input_strategy": "AudioSamples",
"enable_musan": False,
"enable_spec_aug": False,
"return_cuts": True,
"drop_last": False,
# parameters used by quantizer
"embedding_dim": 1280,
}
)
return params
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
# Options about model loading.
parser.add_argument(
"--hubert-model-dir",
type=Path,
default="./pruned_transducer_stateless6/exp/hubert_models/",
help="path to save downloaded hubert models.",
)
parser.add_argument(
"--teacher-model-id",
type=str,
default="hubert_xtralarge_ll60k_finetune_ls960",
help="""could be one of:
[
"hubert_xtralarge_ll60k_finetune_ls960", # fintuned model.
"hubert_xtralarge_ll60k.pt", # pretrained model without fintuing.
]""",
)
parser.add_argument(
"--total-layers",
type=int,
default=48,
)
# Modified from HubertModel.forward to extract the outputs of all middle layers
def extract_layers_result(
self,
batch: Dict,
) -> List[Tuple]:
"""
Extract activations from all layers.
"""
features = batch["inputs"]
# corresponding task.normalize in fairseq
features = torch.nn.functional.layer_norm(features, features.shape)
supervisions = batch["supervisions"]
num_samples = supervisions["num_samples"]
B, T = features.shape
padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape(
[-1, 1]
)
padding_mask = padding_mask.to(self.params.device)
features = features.to(self.params.device)
features = self.w2v_model.forward_features(features)
features = features.transpose(1, 2)
features = self.w2v_model.layer_norm(features)
if padding_mask is not None:
padding_mask = self.w2v_model.forward_padding_mask(
features, padding_mask
)
if self.w2v_model.post_extract_proj is not None:
features = self.w2v_model.post_extract_proj(features)
_, layer_results = self.w2v_model.encoder(
features,
padding_mask=padding_mask,
)
return layer_results
def extract_embedding(self, batch) -> Tuple[torch.Tensor, List[int]]:
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
assert all(c.start == 0 for c in cut_list)
layer_results = self.extract_layers_result(batch)
embeddings = layer_results[self.params.embedding_layer - 1][0]
encoder_embedding = embeddings.transpose(0, 1) # N, T, C
N = encoder_embedding.shape[0]
assert len(cut_list) == N
# 320 = 16,000 / 50 = sample_rate / hubert output frame rate
num_frames = [
supervisions["num_samples"][i].item() // 320 for i in range(N)
]
return encoder_embedding, num_frames
def ctc_greedy_search(self, batch):
"""
Mainly used to verify that the hubert model is used correctly.
"""
layer_results = self.extract_layers_result(batch=batch)
encoder_out = self.w2v_model.encoder.layer_norm(
layer_results[self.params.total_layers - 1][0]
)
encoder_out = self.model.w2v_encoder.proj(encoder_out.transpose(0, 1))
toks = encoder_out.argmax(dim=-1)
blank = 0
toks = [tok.unique_consecutive() for tok in toks]
hyps = [
self.processor.string(tok[tok != blank].int().cpu()) for tok in toks
]
hyps = [post_process(hyp, "letter") for hyp in hyps]
return hyps

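ctc_greedy_search above follows the standard CTC decoding rule: take the argmax token per frame, collapse consecutive repeats, then drop blanks. A self-contained sketch of that rule with toy logits, independent of fairseq and of the hubert dictionary:

import torch

def ctc_greedy_decode(logits: torch.Tensor, blank: int = 0):
    """logits: (T, V) per-frame scores; returns token ids after CTC collapsing."""
    toks = logits.argmax(dim=-1)           # best token per frame
    toks = torch.unique_consecutive(toks)  # collapse repeated frames
    return toks[toks != blank].tolist()    # drop blank tokens

logits = torch.randn(20, 5)
print(ctc_greedy_decode(logits))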
View File

@ -23,6 +23,8 @@ from scaling import ScaledLinear
from icefall.utils import add_sos
from quantization.prediction import JointCodebookLoss
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
@ -38,6 +40,7 @@ class Transducer(nn.Module):
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
num_codebooks: int = 0,
):
"""
Args:
@ -55,6 +58,8 @@ class Transducer(nn.Module):
(N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output
contains unnormalized probs, i.e., not processed by log-softmax.
num_codebooks:
Used by distillation loss.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
@ -68,6 +73,10 @@ class Transducer(nn.Module):
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if num_codebooks > 0:
self.codebook_loss_net = JointCodebookLoss(
predictor_channels=encoder_dim, num_codebooks=num_codebooks
)
def forward(
self,
@ -78,6 +87,7 @@ class Transducer(nn.Module):
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
codebook_indexes: torch.Tensor = None,
) -> torch.Tensor:
"""
Args:
@ -101,6 +111,8 @@ class Transducer(nn.Module):
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
codebook_indexes:
codebook_indexes extracted from a teacher model.
Returns:
Return the transducer loss.
@ -116,7 +128,22 @@ class Transducer(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
encoder_out = layer_results[-1]
middle_layer_output = layer_results[0]
if self.training and codebook_indexes is not None:
assert hasattr(self, "codebook_loss_net")
if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
codebook_indexes = self.concat_sucessive_codebook_indexes(
middle_layer_output, codebook_indexes
)
codebook_loss = self.codebook_loss_net(
middle_layer_output, codebook_indexes
)
else:
# when codebook index is not available.
codebook_loss = None
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
@ -191,4 +218,32 @@ class Transducer(nn.Module):
reduction="sum",
)
return (simple_loss, pruned_loss)
return (simple_loss, pruned_loss, codebook_loss)
@staticmethod
def concat_sucessive_codebook_indexes(
middle_layer_output, codebook_indexes
):
# The output rate of hubert is 50 frames per second,
# while that of the current encoder is 25.
# The following code handles two issues:
# 1.
# Roughly speaking, to produce one more output frame,
# hubert needs two extra input frames,
# while the current encoder needs four extra input frames.
# So if only three extra frames are provided,
# hubert produces one more frame while the current encoder produces nothing.
# 2.
# The codebook loss is a frame-wise loss. To let the 25 fps student output
# learn from the 50 fps teacher output, every two successive frames of the
# teacher output are concatenated together.
t_expected = middle_layer_output.shape[1]
N, T, C = codebook_indexes.shape
# Handling issue 1.
if T >= t_expected * 2:
codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
# Handling issue 2.
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
return codebook_indexes

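The concat_sucessive_codebook_indexes logic above can be exercised in isolation: with a 50 fps teacher and a 25 fps student, the few extra teacher frames are trimmed and every two successive teacher frames are stacked along the codebook dimension. A minimal sketch with dummy tensors (shapes only, no real codebook indexes):

import torch

N, T_student, C = 2, 100, 8  # batch, student frames, codebooks per teacher frame
# The teacher runs at roughly twice the student rate, here with 3 extra frames.
codebook_indexes = torch.randint(0, 256, (N, 2 * T_student + 3, C))

t_expected = T_student
if codebook_indexes.shape[1] >= t_expected * 2:
    # Issue 1: drop the extra teacher frames.
    codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
# Issue 2: concatenate every two successive teacher frames.
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
print(codebook_indexes.shape)  # torch.Size([2, 100, 16])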
View File

@ -20,7 +20,7 @@
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless4/test_model.py
python ./pruned_transducer_stateless6/test_model.py
"""
import torch
@ -33,6 +33,7 @@ def test_model():
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.enable_distiallation = False
model = get_transducer_model(params)

View File

@ -22,25 +22,36 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless4/train.py \
./pruned_transducer_stateless6/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless2/exp \
--exp-dir pruned_transducer_stateless6/exp \
--full-libri 1 \
--max-duration 300
# For mix precision training:
./pruned_transducer_stateless4/train.py \
./pruned_transducer_stateless6/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--exp-dir pruned_transducer_stateless6/exp \
--full-libri 1 \
--max-duration 550
# For distillation with codebook_indexes:
./pruned_transducer_stateless6/train.py \
--manifest-dir ./data/vq_fbank \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless6/exp \
--full-libri 0 \
--max-duration 300
"""
@ -62,9 +73,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.cut import Cut, MonoCut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from lhotse.dataset.collation import collate_custom_field
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
@ -143,7 +155,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless6/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@ -223,6 +235,13 @@ def get_parser():
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--codebook-loss-scale",
type=float,
default=0.1,
help="The scale of codebook loss.",
)
parser.add_argument(
"--seed",
type=int,
@ -352,6 +371,13 @@ def get_params() -> AttributeDict:
# parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(),
# parameters for distillation with codebook indexes.
"enable_distiallation": True,
"distillation_layer": 5, # 0-based index
# Since the output rate of hubert is 50 fps while that of the encoder is 25 fps,
# two successive codebook indexes are concatenated together.
# See Transducer.concat_sucessive_codebook_indexes for details.
"num_codebooks": 16, # used to construct distillation loss
}
)
@ -367,6 +393,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
middle_output_layer=params.distillation_layer
if params.enable_distiallation
else None,
)
return encoder
@ -404,6 +433,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
num_codebooks=params.num_codebooks
if params.enable_distiallation
else 0,
)
return model
@ -527,6 +559,18 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
def extract_codebook_indexes(batch):
cuts = batch["supervisions"]["cut"]
# -100 is the same as the ignore_index used in CE loss computation.
cuts_pre_mixed = [
c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
]
codebook_indexes, codebook_indexes_lens = collate_custom_field(
cuts_pre_mixed, "codebook_indexes", pad_value=-100
)
return codebook_indexes, codebook_indexes_lens
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
@ -570,8 +614,15 @@ def compute_loss(
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
info = MetricsTracker()
if is_training and params.enable_distiallation:
codebook_indexes, _ = extract_codebook_indexes(batch)
codebook_indexes = codebook_indexes.to(device)
else:
codebook_indexes = None
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
simple_loss, pruned_loss, codebook_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -579,6 +630,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
codebook_indexes=codebook_indexes,
)
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
@ -593,10 +645,12 @@ def compute_loss(
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
if is_training and params.enable_distiallation:
assert codebook_loss is not None
loss += params.codebook_loss_scale * codebook_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
@ -607,6 +661,8 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if is_training and params.enable_distiallation:
info["codebook_loss"] = codebook_loss.detach().cpu().item()
return loss, info

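The compute_loss changes above boil down to adding one more weighted term to the total loss. A toy sketch of the combination; the simple/pruned scales are assumptions for illustration, only the 0.1 codebook scale is the --codebook-loss-scale default shown above:

import torch

simple_loss = torch.tensor(10.0)
pruned_loss = torch.tensor(20.0)
codebook_loss = torch.tensor(5.0)  # None when distillation is disabled

simple_loss_scale = 0.5    # assumed value of --simple-loss-scale
pruned_loss_scale = 1.0    # assumed value after warmup
codebook_loss_scale = 0.1  # default of --codebook-loss-scale

loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if codebook_loss is not None:
    loss = loss + codebook_loss_scale * codebook_loss
print(loss)  # tensor(25.5000)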
View File

@ -0,0 +1,393 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import glob
import logging
import os
from functools import cached_property
from pathlib import Path
from typing import List, Tuple
import numpy as np
import torch
import torch.multiprocessing as mp
import quantization
from asr_datamodule import LibriSpeechAsrDataModule
from hubert_xlarge import HubertXlargeFineTuned
from icefall.utils import (
AttributeDict,
setup_logger,
)
from lhotse import CutSet, load_manifest
from lhotse.features.io import NumpyHdf5Writer
class CodebookIndexExtractor:
"""
A wrapper of quantization.Quantizer.
It's responsible for:
1. extracting and saving activations from a teacher model;
2. training a quantizer from those activations;
3. extracting codebook indexes for the whole training set.
Normally this last step needs multiple GPUs.
"""
def __init__(self, params: AttributeDict):
self.params = params
params.subsets = ["clean-100"]
if self.params.full_libri:
self.params.subsets += ["clean-360", "other-500"]
self.init_dirs()
setup_logger(f"{self.vq_dir}/log-vq_extraction")
def init_dirs(self):
# vq_dir is the root dir for the quantizer:
# its training data, the quantizer itself, and the extracted codebook indexes
self.vq_dir = (
self.params.exp_dir / f"vq/{self.params.teacher_model_id}/"
)
self.vq_dir.mkdir(parents=True, exist_ok=True)
# manifest_dir holds:
# the split original manifests,
# the extracted codebook indexes and their related manifests
self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}"
self.manifest_dir.mkdir(parents=True, exist_ok=True)
self.ori_manifest_dir = Path("./data/fbank/")
self.dst_manifest_dir = Path("./data/vq_fbank/")
self.dst_manifest_dir.mkdir(parents=True, exist_ok=True)
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
# Options about teacher embedding extraction.
parser.add_argument(
"--embedding-layer",
type=int,
help="the layer to extract teacher embeddings from, 1-based.",
default=36,
)
parser.add_argument(
"--num-utts",
type=int,
default=1000,
help="num utts to train quantizer",
)
parser.add_argument(
"--num-codebooks",
type=int,
default=8,
help="""number of codebooks,
i.e. the number of codebook indexes each teacher embedding is compressed into.
""",
)
@property
def embedding_file_path(self):
"""
The saved embeddings are used to train the quantizer.
"""
embedding_file_id = (
f"num_utts_{self.params.num_utts}"
+ f"-layer_{self.params.embedding_layer}"
+ "-embedding_embeddings.h5"
)
embedding_file_path = self.vq_dir / embedding_file_id
return embedding_file_path
@torch.no_grad()
def extract_and_save_embedding(self):
"""
The extracted embeddings are used to train the quantizer.
"""
if self.embedding_file_path.exists():
warn_message = (
f"{self.embedding_file_path} already exists."
+ " Skip extracting embeddings from teacher model"
)
logging.warn(warn_message)
return
total_cuts = 0
with NumpyHdf5Writer(self.embedding_file_path) as writer:
for batch_idx, batch in enumerate(self.quantizer_train_dl):
cut_list = batch["supervisions"]["cut"]
(
encoder_embedding,
num_frames,
) = self.teacher_model.extract_embedding(batch)
encoder_embedding = encoder_embedding.cpu().numpy()
for idx, cut in enumerate(cut_list):
cut.encoder_embedding = writer.store_array(
key=cut.id,
value=encoder_embedding[idx][: num_frames[idx]],
)
total_cuts += len(cut_list)
logging.info(
f"Processed {total_cuts} output of {self.params.num_utts} cuts."
)
logging.info(f"Processed all {total_cuts} cuts.")
@property
def quantizer_train_dl(self):
# used to train quantizer.
librispeech = LibriSpeechAsrDataModule(self.params)
quantizer_train_cuts = librispeech.train_clean_100_cuts().subset(
first=self.params.num_utts
)
return librispeech.train_dataloaders(quantizer_train_cuts)
@cached_property
def quantizer_file_path(self):
quantizer_file_id = (
f"num_utts-{self.params.num_utts}"
+ f"-layer-{self.params.embedding_layer}"
+ f"-num_codebooks_{self.params.num_codebooks}"
+ "-quantizer.pt"
)
quantizer_file_path = Path(self.vq_dir) / quantizer_file_id
return quantizer_file_path
def train_quantizer(self):
if self.quantizer_file_path.exists():
warn_message = (
f"{self.quantizer_file_path} already exists."
+ " Skip trainning quantizer."
)
logging.warn(warn_message)
return
assert self.embedding_file_path.exists()
trainer = quantization.QuantizerTrainer(
dim=self.params.embedding_dim,
bytes_per_frame=self.params.num_codebooks,
device=self.params.device,
)
train, valid = quantization.read_hdf5_data(self.embedding_file_path)
B = 512 # Minibatch size, this is very arbitrary, it's close to what we used
# when we tuned this method.
def minibatch_generator(data: torch.Tensor, repeat: bool):
assert 3 * B < data.shape[0]
cur_offset = 0
while True if repeat else cur_offset + B <= data.shape[0]:
start = cur_offset % (data.shape[0] + 1 - B)
end = start + B
cur_offset += B
yield data[start:end, :].to(self.params.device).to(
dtype=torch.float
)
for x in minibatch_generator(train, repeat=True):
trainer.step(x)
if trainer.done():
break
quantizer = trainer.get_quantizer()
torch.save(quantizer.state_dict(), self.quantizer_file_path)
def split_ori_manifests(self):
"""
When multiple GPUs are available, split the original manifests
and extract codebook indexes in parallel.
"""
for subset in self.params.subsets:
logging.info(f"About to split {subset}.")
ori_manifest = f"./data/fbank/cuts_train-{subset}.json.gz"
split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}"
os.system(f"{split_cmd}")
def merge_vq_manifests(self):
"""
Merge the generated manifests containing codebook indexes and store them in self.dst_manifest_dir.
"""
for subset in self.params.subsets:
vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-cuts_train-{subset}*.json.gz"
dst_vq_manifest = (
self.dst_manifest_dir / f"cuts_train-{subset}.json.gz"
)
if 1 == self.params.world_size:
merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}"
else:
merge_cmd = f"lhotse combine {vq_manifests} {dst_vq_manifest}"
os.system(f"{merge_cmd}")
def reuse_manifests(self):
"""
Codebook indexes are only extracted for the train-* subsets.
The rest of the subsets are just symlinked from ./data/fbank.
"""
def is_train(manifest: str) -> bool:
for train_subset in ["clean-100", "clean-360", "other-500"]:
if train_subset in manifest:
return True
return False
reusable_manifests = [
manifest
for manifest in glob.glob(f"{self.ori_manifest_dir}/*.gz")
if not is_train(manifest)
]
for manifest_path in reusable_manifests:
ori_manifest_path = Path(manifest_path).resolve()
dst_manifest_path = Path(
manifest_path.replace(
str(self.ori_manifest_dir), str(self.dst_manifest_dir)
)
).resolve()
if not dst_manifest_path.exists():
os.symlink(ori_manifest_path, dst_manifest_path)
def create_vq_fbank(self):
self.reuse_manifests()
self.merge_vq_manifests()
@cached_property
def teacher_model(self):
return HubertXlargeFineTuned(self.params)
@cached_property
def quantizer(self):
assert self.quantizer_file_path.exists()
quantizer = quantization.Quantizer(
dim=self.params.embedding_dim,
num_codebooks=self.params.num_codebooks,
codebook_size=256,
)
quantizer.load_state_dict(torch.load(self.quantizer_file_path))
quantizer.to(self.params.device)
return quantizer
def load_ori_dl(self, subset):
if self.params.world_size == 1:
ori_manifest_path = f"./data/fbank/cuts_train-{subset}.json.gz"
else:
ori_manifest_path = (
self.manifest_dir
/ f"cuts_train-{subset}.{self.params.manifest_index}.json.gz"
)
cuts = load_manifest(ori_manifest_path)
dl = LibriSpeechAsrDataModule(self.params).train_dataloaders(cuts)
return dl
def _release_gpu_memory(self):
self.__dict__.pop("teacher_model", None)
self.__dict__.pop("quantizer", None)
torch.cuda.empty_cache()
def extract_codebook_indexes(self):
if self.params.world_size == 1:
self.extract_codebook_indexes_imp()
else:
# Since a new extractor will be created for each rank in
# compute_codebook_indexes_parallel, it's better to
# release the GPU memory occupied by current extractor.
self._release_gpu_memory()
# Prepare split manifests for each job.
self.split_ori_manifests()
mp.spawn(
compute_codebook_indexes_parallel,
args=(self.params,),
nprocs=self.params.world_size,
join=True,
)
self.create_vq_fbank()
@torch.no_grad()
def extract_codebook_indexes_imp(self):
for subset in self.params.subsets:
num_cuts = 0
cuts = []
if self.params.world_size == 1:
manifest_file_id = f"{subset}"
else:
manifest_file_id = f"{subset}-{self.params.manifest_index}"
manifest_file_path = self.manifest_dir / manifest_file_id
with NumpyHdf5Writer(manifest_file_path) as writer:
for batch_idx, batch in enumerate(self.load_ori_dl(subset)):
(
encoder_embedding,
num_frames,
) = self.teacher_model.extract_embedding(batch)
codebook_indexes = self.quantizer.encode(encoder_embedding)
# [N, T, C]
codebook_indexes = codebook_indexes.to("cpu").numpy()
assert np.min(codebook_indexes) >= 0
assert np.max(codebook_indexes) < 256
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
assert len(cut_list) == codebook_indexes.shape[0]
assert all(c.start == 0 for c in supervisions["cut"])
for idx, cut in enumerate(cut_list):
cut.codebook_indexes = writer.store_array(
key=cut.id,
value=codebook_indexes[idx][: num_frames[idx]],
frame_shift=0.02,
temporal_dim=0,
start=0,
)
cuts += cut_list
num_cuts += len(cut_list)
message = f"Processed {num_cuts} cuts from {subset}"
if self.params.world_size > 1:
message += f" by job {self.params.manifest_index}"
logging.info(f"{message}.")
json_file_path = (
self.manifest_dir
/ f"with_codebook_indexes-cuts_train-{manifest_file_id}.json.gz"
)
CutSet.from_cuts(cuts).to_json(json_file_path)
@torch.no_grad()
def compute_codebook_indexes_parallel(
rank: int,
params,
) -> List[Tuple[str, List[int]]]:
"""Create an extractor for each rank and extract codebook indexes parallelly.
Normally, this function is called by torch.multiprocessing
when multi GPUs are available.
"""
params = copy.deepcopy(params)
device = torch.device("cuda", rank)
params.device = device
# rank is 0-based, while the manifests split by "lhotse split" are 1-based.
params.manifest_index = rank + 1
extractor = CodebookIndexExtractor(params=params)
extractor.extract_codebook_indexes_imp()

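The minibatch_generator inside train_quantizer above walks over the stacked embeddings in fixed-size chunks and, when repeat=True, wraps around indefinitely. A standalone sketch with a tiny tensor and a much smaller B than the 512 used above, just to show the indexing:

import torch

B = 4  # toy minibatch size

def minibatch_generator(data: torch.Tensor, repeat: bool):
    assert 3 * B < data.shape[0]
    cur_offset = 0
    while True if repeat else cur_offset + B <= data.shape[0]:
        start = cur_offset % (data.shape[0] + 1 - B)
        cur_offset += B
        yield data[start : start + B]

data = torch.arange(32, dtype=torch.float32).reshape(16, 2)
print(sum(1 for _ in minibatch_generator(data, repeat=False)))  # 4 minibatches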
View File

@ -25,7 +25,7 @@ from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
BucketingSampler,
CutConcatenate,
CutMix,
@ -34,7 +34,10 @@ from lhotse.dataset import (
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
@ -150,6 +153,12 @@ class LibriSpeechAsrDataModule:
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
@ -192,6 +201,13 @@ class LibriSpeechAsrDataModule:
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
@ -263,6 +279,7 @@ class LibriSpeechAsrDataModule:
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
@ -296,7 +313,7 @@ class LibriSpeechAsrDataModule:
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
@ -371,7 +388,7 @@ class LibriSpeechAsrDataModule:
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = BucketingSampler(

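The new --input-strategy option above is resolved with eval(self.args.input_strategy)() against the imported classes. An equivalent, more explicit mapping is sketched below; it is not the recipe's code and assumes only the two strategies named in the help text:

from lhotse.dataset.input_strategies import AudioSamples, PrecomputedFeatures

_INPUT_STRATEGIES = {
    "AudioSamples": AudioSamples,                # raw waveforms, used by the hubert teacher
    "PrecomputedFeatures": PrecomputedFeatures,  # precomputed fbank features (default)
}

def make_input_strategy(name: str):
    # Same effect as eval(name)(), but restricted to known classes.
    return _INPUT_STRATEGIES[name]()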
View File

@ -127,7 +127,11 @@ def setup_logger(
level = logging.CRITICAL
logging.basicConfig(
filename=log_filename, format=formatter, level=level, filemode="w"
filename=log_filename,
format=formatter,
level=level,
filemode="w",
force=True,
)
if use_console:
console = logging.StreamHandler()