support MVQ KD training in conv_emformer

marcoyang 2022-12-29 16:01:04 +08:00
parent aa0fe4e4ac
commit e91fbef939
3 changed files with 167 additions and 11 deletions

View File

@@ -1133,7 +1133,10 @@ class EmformerEncoder(nn.Module):
       tanh_on_mem (bool, optional):
         If ``true``, applies tanh to memory elements. (default: ``false``)
       negative_inf (float, optional):
-        Value to use for negative infinity in attention weights. (default: -1e8)
+        Value to use for negative infinity in attention weights. (default: -1e8),
+      output_layers:
+        A list of integers containing the indexes of the Emformer layers
+        whose activations will be returned.
     """

     def __init__(
@@ -1151,6 +1154,7 @@
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        output_layers: List[int] = None,
     ):

         super().__init__()
@@ -1188,6 +1192,7 @@
         self.chunk_length = chunk_length
         self.memory_size = memory_size
         self.cnn_module_kernel = cnn_module_kernel
+        self.output_layers = output_layers

     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
         """Hard copy each chunk's right context and concat them."""
@@ -1366,7 +1371,8 @@
         padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)

         output = utterance
-        for layer in self.emformer_layers:
+        layer_results = []
+        for layer_index, layer in enumerate(self.emformer_layers):
             output, right_context = layer(
                 output,
                 right_context,
@@ -1374,8 +1380,11 @@
                 padding_mask=padding_mask,
                 warmup=warmup,
             )
+            if layer_index in self.output_layers:
+                # (T, N, C) --> (N, T, C)
+                layer_results.append(output.permute(1, 0, 2))

-        return output, output_lengths
+        return layer_results, output_lengths

     @torch.jit.export
     def infer(
@@ -1545,6 +1554,7 @@
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        middle_output_layer: int = None,  # 0-based layer index
     ):

         super().__init__()
@@ -1573,6 +1583,17 @@
         # (2) embedding: num_features -> d_model
         self.encoder_embed = Conv2dSubsampling(num_features, d_model)

+        output_layers = []
+        if middle_output_layer is not None:
+            assert (
+                middle_output_layer >= 0
+                and middle_output_layer < num_encoder_layers
+            ), "Invalid middle output layer"
+            output_layers.append(middle_output_layer)
+
+        # The last layer is always needed.
+        output_layers.append(num_encoder_layers - 1)
+
         self.encoder = EmformerEncoder(
             chunk_length=chunk_length // subsampling_factor,
             d_model=d_model,
@@ -1587,7 +1608,8 @@
             memory_size=memory_size,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
-        )
+            output_layers=output_layers,  # for distillation
+        )

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -1624,9 +1646,7 @@
         x_lens = (((x_lens - 1) >> 1) - 1) >> 1
         assert x.size(0) == x_lens.max().item()

-        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (T, N, C)
-        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # a list of (N, T, C) tensors

         return output, output_lengths
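
A minimal sketch of the output_layers bookkeeping added above; the values of num_encoder_layers and middle_output_layer below are hypothetical and only illustrate which activations end up in layer_results:

num_encoder_layers = 12     # hypothetical value, for illustration only
middle_output_layer = 8     # 0-based index of the distillation tap

output_layers = []
if middle_output_layer is not None:
    assert 0 <= middle_output_layer < num_encoder_layers, "Invalid middle output layer"
    output_layers.append(middle_output_layer)
# The last layer is always needed.
output_layers.append(num_encoder_layers - 1)

# layer_results[0] is then the distillation tap and layer_results[-1] the final output.
print(output_layers)  # [8, 11]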

View File

@@ -74,7 +74,8 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from emformer import Emformer
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
@@ -357,6 +358,41 @@
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=8,
+        help="On which encoder layer to perform KD",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks",
+    )
+
+    parser.add_argument(
+        "--distil-delta",
+        type=int,
+        default=None,
+        help="Offset when doing KD",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of codebook loss.",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -446,6 +482,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         left_context_length=params.left_context_length,
         right_context_length=params.right_context_length,
         memory_size=params.memory_size,
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder
@@ -483,6 +522,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        distil_delta=params.distil_delta if params.enable_distillation else 0,
     )
     return model
@@ -605,6 +646,16 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+
+def extract_codebook_indexes(batch):
+    cuts = batch["supervisions"]["cut"]
+    # -100 is the ignore value used in the CE loss computation.
+    cuts_pre_mixed = [
+        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+    ]
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens
+

 def compute_loss(
     params: AttributeDict,
@@ -645,8 +696,14 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -654,6 +711,7 @@
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indexes=codebook_indexes,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -664,6 +722,10 @@
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

+        if is_training and params.enable_distillation:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss
+
     assert loss.requires_grad == is_training

     info = MetricsTracker()
@@ -684,6 +746,8 @@
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info
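
A small numeric sketch of how the codebook (MVQ) term enters the total loss above; the loss values and the simple/pruned scales are placeholders, only codebook_loss_scale mirrors the new --codebook-loss-scale default of 0.1:

import torch

simple_loss = torch.tensor(12.0)   # placeholder values, not real training numbers
pruned_loss = torch.tensor(8.0)
codebook_loss = torch.tensor(30.0)

simple_loss_scale = 0.5            # assumed scales, for illustration only
pruned_loss_scale = 1.0
codebook_loss_scale = 0.1          # matches the new --codebook-loss-scale default

loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
# The distillation term is added only when --enable-distillation is true.
loss = loss + codebook_loss_scale * codebook_loss
print(loss)  # tensor(17.)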

View File

@@ -1,4 +1,5 @@
 # Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang, Wei Kang)
+#              2022  Xiaomi Corp.        (authors: Zengwei Yao, Liyong Guo, Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -40,6 +41,8 @@ class Transducer(nn.Module):
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
+        num_codebooks: int = 0,
+        distil_delta: int = None,
     ):
         """
         Args:
@@ -68,6 +71,16 @@
         self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

+        from multi_quantization.prediction import JointCodebookLoss
+
+        self.distil_delta = distil_delta
+
+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=encoder_dim,
+                num_codebooks=num_codebooks,
+                is_joint=False,
+            )
+
     def forward(
         self,
@@ -80,6 +93,7 @@
         warmup: float = 1.0,
         reduction: str = "sum",
         delay_penalty: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -112,6 +126,8 @@
             streaming models to emit symbols earlier.
             See https://github.com/k2-fsa/k2/issues/955 and
             https://arxiv.org/pdf/2211.00490.pdf for more details.
+          codebook_indexes:
+            Codebook indexes extracted from a teacher model.
         Returns:
           Return the transducer loss.
@@ -129,7 +145,35 @@
         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        encoder_out = layer_results[-1]  # the last item is the final output
+        middle_layer_output = layer_results[0]
+
+        if self.training and codebook_indexes is not None:
+            assert hasattr(self, "codebook_loss_net")
+            # Handle the different subsampling ratios of the HuBERT teacher and the Emformer student.
+            if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
+                codebook_indexes = self.concat_successive_codebook_indexes(
+                    middle_layer_output, codebook_indexes
+                )
+            if self.distil_delta is not None:
+                N = codebook_indexes.shape[0]
+                T = codebook_indexes.shape[1]
+                cur_distil_delta = self.distil_delta
+                # Align the teacher with (student + self.distil_delta).
+                # Suppose self.distil_delta == 2:
+                invalid_teacher_mask = codebook_indexes == -100
+                # 1,2,3,4,5,6,7,8,-100,-100 --> 1,2,1,2,3,4,5,6,7,8
+                codebook_indexes[:, cur_distil_delta:, :] = codebook_indexes.clone()[:, : T - cur_distil_delta, :]
+                invalid_teacher_mask[:, :cur_distil_delta] = True
+                codebook_indexes.masked_fill_(invalid_teacher_mask, -100)
+                # --> -100,-100,1,2,3,4,5,6,-100,-100
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        else:
+            # When codebook indexes are not available.
+            codebook_loss = None

         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
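
The distil_delta shift above can be checked in isolation; a minimal runnable sketch, assuming distil_delta == 2, a single codebook, and one utterance with 8 valid teacher frames followed by 2 padded frames (it reproduces the pattern shown in the comments):

import torch

codebook_indexes = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, -100, -100]]).unsqueeze(-1)  # (N=1, T=10, C=1)
T = codebook_indexes.shape[1]
cur_distil_delta = 2  # assumed value, for illustration only

invalid_teacher_mask = codebook_indexes == -100
# Shift the teacher right by cur_distil_delta frames.
codebook_indexes[:, cur_distil_delta:, :] = codebook_indexes.clone()[:, : T - cur_distil_delta, :]
# Mark the first cur_distil_delta frames and the original padding as invalid.
invalid_teacher_mask[:, :cur_distil_delta] = True
codebook_indexes.masked_fill_(invalid_teacher_mask, -100)
print(codebook_indexes.squeeze(-1))  # tensor([[-100, -100, 1, 2, 3, 4, 5, 6, -100, -100]])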
@@ -204,4 +248,32 @@
             reduction=reduction,
         )

-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, codebook_loss)
+
+    @staticmethod
+    def concat_successive_codebook_indexes(
+        middle_layer_output, codebook_indexes
+    ):
+        # The output rate of the HuBERT teacher is 50 frames per second,
+        # while that of the current encoder is 25 frames per second.
+        # The code below handles two issues:
+        # 1.
+        #    Roughly speaking, to generate one more output frame, HuBERT needs
+        #    two extra input frames, while the current encoder needs four.
+        #    If only three extra frames are provided, HuBERT generates one more
+        #    frame while the current encoder does not, so the teacher output is
+        #    truncated to exactly twice the student length below.
+        # 2.
+        #    The codebook loss is a frame-wise loss. To let the 25 Hz student
+        #    output learn from the 50 Hz teacher output, two successive frames
+        #    of the teacher output are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handle issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        # Handle issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
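
A toy check of concat_successive_codebook_indexes, assuming a 50 Hz teacher (10 frames, 4 codebooks) paired with a 25 Hz student (5 frames); the tensors and dimensions are made up, and two successive teacher frames are folded into the codebook dimension so both sides end up with the same number of frames:

import torch

N, T, C = 2, 10, 4
codebook_indexes = torch.arange(N * T * C).reshape(N, T, C)  # fake teacher indexes
middle_layer_output = torch.zeros(N, 5, 256)                 # fake student activations, (N, T', C')

t_expected = middle_layer_output.shape[1]
assert T >= t_expected
# Truncate the teacher to exactly twice the student length (issue 1).
if T >= t_expected * 2:
    codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
# Concatenate every two successive teacher frames (issue 2).
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
print(codebook_indexes.shape)  # torch.Size([2, 5, 8])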