support MVQ KD training in conv_emformer

This commit is contained in:
marcoyang 2022-12-29 16:01:04 +08:00
parent aa0fe4e4ac
commit e91fbef939
3 changed files with 167 additions and 11 deletions

View File

@@ -1133,7 +1133,10 @@ class EmformerEncoder(nn.Module):
tanh_on_mem (bool, optional):
If ``true``, applies tanh to memory elements. (default: ``false``)
negative_inf (float, optional):
Value to use for negative infinity in attention weights. (default: -1e8)
output_layers:
A list of integers containing the indexes of the Emformer layers
whose activations will be returned.
"""
def __init__(
@@ -1151,6 +1154,7 @@ class EmformerEncoder(nn.Module):
memory_size: int = 0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
output_layers: Optional[List[int]] = None,
):
super().__init__()
@@ -1188,6 +1192,7 @@ class EmformerEncoder(nn.Module):
self.chunk_length = chunk_length
self.memory_size = memory_size
self.cnn_module_kernel = cnn_module_kernel
self.output_layers = output_layers
def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
"""Hard copy each chunk's right context and concat them."""
@@ -1366,7 +1371,8 @@ class EmformerEncoder(nn.Module):
padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
output = utterance
for layer in self.emformer_layers:
layer_results = []
for layer_index, layer in enumerate(self.emformer_layers):
output, right_context = layer(
output,
right_context,
@@ -1374,8 +1380,11 @@ class EmformerEncoder(nn.Module):
padding_mask=padding_mask,
warmup=warmup,
)
if layer_index in self.output_layers:
# (T, N, C) --> (N, T, C)
layer_results.append(output.permute(1, 0, 2))
return output, output_lengths
return layer_results, output_lengths
@torch.jit.export
def infer(
@@ -1545,6 +1554,7 @@ class Emformer(EncoderInterface):
memory_size: int = 0,
tanh_on_mem: bool = False,
negative_inf: float = -1e8,
middle_output_layer: Optional[int] = None,  # 0-based layer index
):
super().__init__()
@@ -1573,6 +1583,17 @@ class Emformer(EncoderInterface):
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
output_layers = []
if middle_output_layer is not None:
assert (
0 <= middle_output_layer < num_encoder_layers
), f"Invalid middle output layer: {middle_output_layer}"
output_layers.append(middle_output_layer)
# The last layer is always needed.
output_layers.append(num_encoder_layers - 1)
self.encoder = EmformerEncoder(
chunk_length=chunk_length // subsampling_factor,
d_model=d_model,
@@ -1587,7 +1608,8 @@ class Emformer(EncoderInterface):
memory_size=memory_size,
tanh_on_mem=tanh_on_mem,
negative_inf=negative_inf,
)
output_layers=output_layers, # for distillation
)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -1624,9 +1646,7 @@ class Emformer(EncoderInterface):
x_lens = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == x_lens.max().item()
output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C)
output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # a list of (N, T, C) tensors
return output, output_lengths
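Note: the net effect of the encoder changes above is that forward() now returns a list of per-layer activations rather than a single tensor, with the final layer's output always last. Below is a minimal, runnable sketch of the same pattern on a toy module (the class, dimensions, and layer indexes are illustrative assumptions, mirroring middle_output_layer=8 with 12 encoder layers):

import torch
import torch.nn as nn

# Toy stand-in for the modified EmformerEncoder: run a stack of layers and
# collect the activations of the layers listed in output_layers.
class TinyStack(nn.Module):
    def __init__(self, num_layers=12, dim=4, output_layers=(8, 11)):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.output_layers = list(output_layers)

    def forward(self, x):
        layer_results = []
        for layer_index, layer in enumerate(self.layers):
            x = layer(x)
            if layer_index in self.output_layers:
                layer_results.append(x)
        # The final layer's output is always the last element.
        return layer_results

results = TinyStack()(torch.randn(2, 3, 4))
assert len(results) == 2 and results[-1].shape == (2, 3, 4)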

View File

@@ -74,7 +74,8 @@ from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from emformer import Emformer
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.cut import Cut, MonoCut
from lhotse.dataset.collation import collate_custom_field
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
@@ -357,6 +358,41 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--enable-distillation",
type=str2bool,
default=True,
help="Whether to eanble distillation.",
)
parser.add_argument(
"--distillation-layer",
type=int,
default=8,
help="On which encoder layer to perform KD"
)
parser.add_argument(
"--num-codebooks",
type=int,
default=16,
help="Number of codebooks"
)
parser.add_argument(
"--distil-delta",
type=int,
default=None,
help="Offset when doing KD"
)
parser.add_argument(
"--codebook-loss-scale",
type=float,
default=0.1,
help="The scale of codebook loss.",
)
add_model_arguments(parser)
return parser
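Note: with the options added above, a distillation run could be launched roughly as follows (the script path and the remaining required options are assumed, not shown in this diff):

./train.py \
  --enable-distillation True \
  --distillation-layer 8 \
  --num-codebooks 16 \
  --codebook-loss-scale 0.1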
@@ -446,6 +482,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
left_context_length=params.left_context_length,
right_context_length=params.right_context_length,
memory_size=params.memory_size,
middle_output_layer=params.distillation_layer
if params.enable_distillation
else None,
)
return encoder
@@ -483,6 +522,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
num_codebooks=params.num_codebooks if params.enable_distillation else 0,
distil_delta=params.distil_delta if params.enable_distillation else 0,
)
return model
@@ -605,6 +646,16 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def extract_codebook_indexes(batch):
cuts = batch["supervisions"]["cut"]
# For mixed cuts (e.g. cuts mixed with noise), the codebook indexes
# are stored on the first track's original cut.
cuts_pre_mixed = [
c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
]
# -100 is identical to ignore_value in CE loss computation.
codebook_indexes, codebook_indexes_lens = collate_custom_field(
cuts_pre_mixed, "codebook_indexes", pad_value=-100
)
return codebook_indexes, codebook_indexes_lens
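Note: below is a minimal sketch of the padding behaviour this helper relies on, using torch's pad_sequence to mimic lhotse's collate_custom_field (toy data; -100 matches the ignore value above):

import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.ones(3, 2, dtype=torch.long)      # cut 1: T=3 frames, C=2 codebooks
b = 2 * torch.ones(5, 2, dtype=torch.long)  # cut 2: T=5 frames, C=2 codebooks
padded = pad_sequence([a, b], batch_first=True, padding_value=-100)
assert padded.shape == (2, 5, 2)
assert (padded[0, 3:] == -100).all()  # cut 1's trailing frames are padding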
def compute_loss(
params: AttributeDict,
@@ -645,8 +696,14 @@ def compute_loss(
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
if is_training and params.enable_distillation:
codebook_indexes, _ = extract_codebook_indexes(batch)
codebook_indexes = codebook_indexes.to(device)
else:
codebook_indexes = None
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
simple_loss, pruned_loss, codebook_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@@ -654,6 +711,7 @@
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
codebook_indexes=codebook_indexes,
)
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
@@ -664,6 +722,10 @@
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if is_training and params.enable_distillation:
assert codebook_loss is not None
loss += params.codebook_loss_scale * codebook_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
@@ -684,6 +746,8 @@
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if is_training and params.enable_distillation:
info["codebook_loss"] = codebook_loss.detach().cpu().item()
return loss, info
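Note: putting the compute_loss changes together, the total training loss with distillation enabled is the scaled sum of the simple, pruned, and codebook losses. A minimal sketch with dummy numbers (the 0.1 codebook scale is the default added above; the other two scale values are assumptions, not part of this diff):

import torch

simple_loss = torch.tensor(2.0)   # dummy values for illustration
pruned_loss = torch.tensor(1.0)
codebook_loss = torch.tensor(5.0)

simple_loss_scale = 0.5           # assumed recipe value, not from this diff
pruned_loss_scale = 1.0           # reaches 1.0 after warmup
codebook_loss_scale = 0.1         # default of --codebook-loss-scale above

loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
loss = loss + codebook_loss_scale * codebook_loss  # only when distillation is on
assert abs(loss.item() - 2.5) < 1e-6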

View File

@@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
# 2022 Xiaomi Corp. (authors: Zengwei Yao, Liyong Guo, Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -40,6 +41,8 @@ class Transducer(nn.Module):
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
num_codebooks: int = 0,
distil_delta: Optional[int] = None,
):
"""
Args:
@@ -68,6 +71,16 @@ class Transducer(nn.Module):
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
self.distil_delta = distil_delta
if num_codebooks > 0:
# Import here so that multi_quantization is only required
# when distillation is actually enabled.
from multi_quantization.prediction import JointCodebookLoss
self.codebook_loss_net = JointCodebookLoss(
predictor_channels=encoder_dim,
num_codebooks=num_codebooks,
is_joint=False,
)
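Note: a minimal sketch (toy shapes; the codebook value range of 256 is an assumption) of how this loss module is invoked later in forward(): the predictor input is the middle-layer output of shape (N, T, encoder_dim), and the targets are integer codebook indexes of shape (N, T, num_codebooks):

import torch
from multi_quantization.prediction import JointCodebookLoss

loss_net = JointCodebookLoss(
    predictor_channels=256,  # toy stand-in for encoder_dim
    num_codebooks=16,
    is_joint=False,
)
middle_layer_output = torch.randn(2, 10, 256)          # (N, T, encoder_dim)
codebook_indexes = torch.randint(0, 256, (2, 10, 16))  # (N, T, num_codebooks)
codebook_loss = loss_net(middle_layer_output, codebook_indexes)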
def forward(
self,
@@ -80,6 +93,7 @@ class Transducer(nn.Module):
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
codebook_indexes: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
@@ -112,6 +126,8 @@
streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
codebook_indexes:
Codebook indexes extracted from a teacher model.
Returns:
Return the transducer loss.
@@ -129,7 +145,35 @@
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
encoder_out = layer_results[-1] # the last item is the final output
# layer_results[0] is the selected distillation layer's output; when no
# middle layer was requested it is identical to encoder_out.
middle_layer_output = layer_results[0]
if self.training and codebook_indexes is not None:
assert hasattr(self, "codebook_loss_net")
# The HuBERT teacher and the Emformer student use different
# subsampling ratios, so the teacher frames may need merging first.
if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
codebook_indexes = self.concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
)
if self.distil_delta is not None:
T = codebook_indexes.shape[1]
cur_distil_delta = self.distil_delta
# Align teacher frame t with student frame (t + self.distil_delta)
# by shifting the teacher's codebook indexes right by distil_delta.
# Suppose self.distil_delta == 2 and one teacher sequence is
# 1,2,3,4,5,6,7,8,-100,-100.
invalid_teacher_mask = codebook_indexes == -100
# After the shift: 1,2,1,2,3,4,5,6,7,8
codebook_indexes[:, cur_distil_delta:, :] = codebook_indexes.clone()[
:, : T - cur_distil_delta, :
]
# The first distil_delta frames have no teacher targets.
invalid_teacher_mask[:, :cur_distil_delta] = True
codebook_indexes.masked_fill_(invalid_teacher_mask, -100)
# Final result: -100,-100,1,2,3,4,5,6,-100,-100
codebook_loss = self.codebook_loss_net(
middle_layer_output, codebook_indexes
)
else:
# Codebook indexes are not available (e.g. during validation).
codebook_loss = None
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
@@ -204,4 +248,32 @@ class Transducer(nn.Module):
reduction=reduction,
)
return (simple_loss, pruned_loss)
return (simple_loss, pruned_loss, codebook_loss)
@staticmethod
def concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
):
# The output rate of the HuBERT teacher is 50 frames per second,
# while that of the current encoder is 25 frames per second.
# The following code handles two issues:
# 1.
# Roughly speaking, to generate one extra output frame, HuBERT
# needs two extra input frames, while the current encoder needs
# four. If only three extra frames are provided, HuBERT generates
# one extra frame while the current encoder generates nothing.
# 2.
# The codebook loss is a frame-wise loss. To enable the 25 Hz
# student output to learn from the 50 Hz teacher output, every
# two successive frames of the teacher output are concatenated.
t_expected = middle_layer_output.shape[1]
N, T, C = codebook_indexes.shape
assert T >= t_expected, (T, t_expected)
# Handle issue 1: drop surplus teacher frames.
if T >= t_expected * 2:
codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
# Handle issue 2: concatenate every two successive teacher frames.
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
return codebook_indexes
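Note: below is a self-contained sketch (illustrative values only) tying together the two alignment steps above: merging successive teacher frames to halve the frame rate, then shifting by distil_delta so that teacher frame t supervises student frame t + delta:

import torch

N, T, C = 1, 8, 2      # teacher: 8 frames, 2 codebooks per frame
t_expected = 4         # student middle layer runs at half the teacher rate
codebook_indexes = torch.arange(N * T * C).reshape(N, T, C)

# Issue 2: concatenate every two successive teacher frames.
merged = codebook_indexes[:, : t_expected * 2, :].reshape(N, t_expected, C * 2)
assert merged.shape == (1, 4, 4)

# distil_delta: shift the teacher right so that teacher frame t
# supervises student frame t + delta; the first delta frames get -100.
delta = 2
shifted = merged.clone()
shifted[:, delta:, :] = merged[:, : t_expected - delta, :]
shifted[:, :delta, :] = -100
assert (shifted[0, delta] == merged[0, 0]).all()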