support MVQ KD training in conv_emformer
This commit is contained in:
parent aa0fe4e4ac
commit e91fbef939
@@ -1133,7 +1133,10 @@ class EmformerEncoder(nn.Module):
      tanh_on_mem (bool, optional):
        If ``true``, applies tanh to memory elements. (default: ``false``)
      negative_inf (float, optional):
        Value to use for negative infinity in attention weights. (default: -1e8)
        Value to use for negative infinity in attention weights. (default: -1e8),
      output_layers:
        A list of integers containing the ids of the Emformer layers whose
        activations will be returned.
    """

    def __init__(
@@ -1151,6 +1154,7 @@ class EmformerEncoder(nn.Module):
        memory_size: int = 0,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
        output_layers: List[int] = None,
    ):
        super().__init__()

@@ -1188,6 +1192,7 @@ class EmformerEncoder(nn.Module):
        self.chunk_length = chunk_length
        self.memory_size = memory_size
        self.cnn_module_kernel = cnn_module_kernel
        self.output_layers = output_layers

    def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
        """Hard copy each chunk's right context and concat them."""
@@ -1366,7 +1371,8 @@ class EmformerEncoder(nn.Module):
        padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)

        output = utterance
        for layer in self.emformer_layers:
        layer_results = []
        for layer_index, layer in enumerate(self.emformer_layers):
            output, right_context = layer(
                output,
                right_context,
@@ -1374,8 +1380,11 @@ class EmformerEncoder(nn.Module):
                padding_mask=padding_mask,
                warmup=warmup,
            )
            if layer_index in self.output_layers:
                # (T, N, C) --> (N, T, C)
                layer_results.append(output.permute(1, 0, 2))

        return output, output_lengths
        return layer_results, output_lengths

    @torch.jit.export
    def infer(
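With this change, EmformerEncoder.forward no longer returns a single tensor; it returns a list holding the activations of every layer listed in output_layers (each converted to (N, T, C)), with the final layer's output as the last element. A minimal, self-contained sketch of that collection pattern, using toy linear layers in place of the real Emformer layers:

import torch
import torch.nn as nn

# Toy stand-ins for the Emformer layers (hypothetical, illustration only).
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(6)])
output_layers = [2, 5]  # a middle layer for distillation plus the last layer

x = torch.randn(10, 3, 8)  # (T, N, C)
layer_results = []
output = x
for layer_index, layer in enumerate(layers):
    output = layer(output)
    if layer_index in output_layers:
        # (T, N, C) -> (N, T, C), matching what the encoder stores
        layer_results.append(output.permute(1, 0, 2))

# layer_results[0] feeds the codebook (distillation) loss,
# layer_results[-1] is the ordinary encoder output.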
@@ -1545,6 +1554,7 @@ class Emformer(EncoderInterface):
        memory_size: int = 0,
        tanh_on_mem: bool = False,
        negative_inf: float = -1e8,
        middle_output_layer: int = None,  # 0-based layer index
    ):
        super().__init__()

@@ -1573,6 +1583,17 @@ class Emformer(EncoderInterface):
        # (2) embedding: num_features -> d_model
        self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        output_layers = []
        if middle_output_layer is not None:
            assert (
                middle_output_layer >= 0
                and middle_output_layer < num_encoder_layers
            ), f"Invalid middle output layer"
            output_layers.append(middle_output_layer)

        # The last layer is always needed.
        output_layers.append(num_encoder_layers - 1)

        self.encoder = EmformerEncoder(
            chunk_length=chunk_length // subsampling_factor,
            d_model=d_model,
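To make the list construction above concrete: with, say, num_encoder_layers = 12 and --distillation-layer 8 (toy values chosen only for illustration), the encoder is asked to return layer 8 for the codebook loss and layer 11 as the regular output:

num_encoder_layers = 12      # hypothetical config value
middle_output_layer = 8      # e.g. from --distillation-layer

output_layers = []
if middle_output_layer is not None:
    assert 0 <= middle_output_layer < num_encoder_layers
    output_layers.append(middle_output_layer)
output_layers.append(num_encoder_layers - 1)  # the last layer is always needed

print(output_layers)  # [8, 11]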
@@ -1587,7 +1608,8 @@ class Emformer(EncoderInterface):
            memory_size=memory_size,
            tanh_on_mem=tanh_on_mem,
            negative_inf=negative_inf,
        )
            output_layers=output_layers,  # for distillation
        )

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -1624,9 +1646,7 @@ class Emformer(EncoderInterface):
        x_lens = (((x_lens - 1) >> 1) - 1) >> 1
        assert x.size(0) == x_lens.max().item()

        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (T, N, C)

        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (N, T, C)

        return output, output_lengths

@@ -74,7 +74,8 @@ from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from emformer import Emformer
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.cut import Cut, MonoCut
from lhotse.dataset.collation import collate_custom_field
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
@@ -357,6 +358,41 @@ def get_parser():
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--enable-distillation",
        type=str2bool,
        default=True,
        help="Whether to enable distillation.",
    )

    parser.add_argument(
        "--distillation-layer",
        type=int,
        default=8,
        help="On which encoder layer to perform KD"
    )

    parser.add_argument(
        "--num-codebooks",
        type=int,
        default=16,
        help="Number of codebooks"
    )

    parser.add_argument(
        "--distil-delta",
        type=int,
        default=None,
        help="Offset when doing KD"
    )

    parser.add_argument(
        "--codebook-loss-scale",
        type=float,
        default=0.1,
        help="The scale of codebook loss.",
    )

    add_model_arguments(parser)

    return parser

@@ -446,6 +482,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
        left_context_length=params.left_context_length,
        right_context_length=params.right_context_length,
        memory_size=params.memory_size,
        middle_output_layer=params.distillation_layer
        if params.enable_distillation
        else None,
    )
    return encoder

@@ -483,6 +522,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
        distil_delta=params.distil_delta if params.enable_distillation else 0,
    )
    return model

@@ -605,6 +646,16 @@ def save_checkpoint(
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)


def extract_codebook_indexes(batch):
    cuts = batch["supervisions"]["cut"]
    # -100 is identical to ignore_value in CE loss computation.
    cuts_pre_mixed = [
        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
    ]
    codebook_indexes, codebook_indexes_lens = collate_custom_field(
        cuts_pre_mixed, "codebook_indexes", pad_value=-100
    )
    return codebook_indexes, codebook_indexes_lens

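The pad value of -100 matters because it matches the default ignore_index of PyTorch's cross-entropy, so padded teacher frames contribute nothing to the codebook loss. A tiny standalone illustration with toy tensors (not the recipe's real data):

import torch
import torch.nn.functional as F

logits = torch.randn(3, 5)            # 3 frames, 5 codebook classes
targets = torch.tensor([1, 4, -100])  # the last frame is padding
# Targets equal to -100 are skipped entirely by the loss.
loss = F.cross_entropy(logits, targets, ignore_index=-100)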
def compute_loss(
    params: AttributeDict,
@@ -645,8 +696,14 @@ def compute_loss(
    y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y).to(device)

    if is_training and params.enable_distillation:
        codebook_indexes, _ = extract_codebook_indexes(batch)
        codebook_indexes = codebook_indexes.to(device)
    else:
        codebook_indexes = None

    with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss = model(
        simple_loss, pruned_loss, codebook_loss = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
@@ -654,6 +711,7 @@ def compute_loss(
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
            warmup=warmup,
            codebook_indexes=codebook_indexes,
        )
        # after the main warmup step, we keep pruned_loss_scale small
        # for the same amount of time (model_warm_step), to avoid
@@ -664,6 +722,10 @@ def compute_loss(
        )
        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

        if is_training and params.enable_distillation:
            assert codebook_loss is not None
            loss += params.codebook_loss_scale * codebook_loss

        assert loss.requires_grad == is_training

    info = MetricsTracker()
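Purely illustrative arithmetic for how the codebook term enters the total objective; the loss values and the simple/pruned scales below are made up, and only the 0.1 codebook scale mirrors the default of --codebook-loss-scale:

simple_loss, pruned_loss, codebook_loss = 10.0, 4.0, 30.0  # toy values
simple_loss_scale, pruned_loss_scale = 0.5, 1.0            # toy scales
codebook_loss_scale = 0.1

loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
loss += codebook_loss_scale * codebook_loss  # only when distillation is enabled
print(loss)  # 12.0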
@@ -684,6 +746,8 @@ def compute_loss(
    info["loss"] = loss.detach().cpu().item()
    info["simple_loss"] = simple_loss.detach().cpu().item()
    info["pruned_loss"] = pruned_loss.detach().cpu().item()
    if is_training and params.enable_distillation:
        info["codebook_loss"] = codebook_loss.detach().cpu().item()

    return loss, info

@@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
# 2022 Xiaomi Corp. (authors: Zengwei Yao, Liyong Guo, Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -40,6 +41,8 @@ class Transducer(nn.Module):
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
        num_codebooks: int = 0,
        distil_delta: int = None,
    ):
        """
        Args:
@@ -68,6 +71,16 @@ class Transducer(nn.Module):

        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

        from multi_quantization.prediction import JointCodebookLoss
        self.distil_delta = distil_delta

        if num_codebooks > 0:
            self.codebook_loss_net = JointCodebookLoss(
                predictor_channels=encoder_dim,
                num_codebooks=num_codebooks,
                is_joint=False,
            )

    def forward(
        self,
@@ -80,6 +93,7 @@ class Transducer(nn.Module):
        warmup: float = 1.0,
        reduction: str = "sum",
        delay_penalty: float = 0.0,
        codebook_indexes: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -112,6 +126,8 @@ class Transducer(nn.Module):
            streaming models to emit symbols earlier.
            See https://github.com/k2-fsa/k2/issues/955 and
            https://arxiv.org/pdf/2211.00490.pdf for more details.
          codebook_indexes:
            Codebook indexes extracted from a teacher model.
        Returns:
          Return the transducer loss.
@@ -129,7 +145,35 @@ class Transducer(nn.Module):

        assert x.size(0) == x_lens.size(0) == y.dim0

        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
        layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
        encoder_out = layer_results[-1]  # the last item is the final output

        middle_layer_output = layer_results[0]
        if self.training and codebook_indexes is not None:
            assert hasattr(self, "codebook_loss_net")
            # due to the different subsampling ratios of the hubert teacher and the emformer
            if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
                codebook_indexes = self.concat_successive_codebook_indexes(
                    middle_layer_output, codebook_indexes
                )
            if self.distil_delta is not None:
                N = codebook_indexes.shape[0]
                T = codebook_indexes.shape[1]
                cur_distil_delta = self.distil_delta
                # align (teacher) with (student + self.distil_delta)
                # suppose self.distil_delta == 2
                unvalid_teacher_mask = codebook_indexes == -100
                # 1,2,3,4,5,6,7,8,-100,-100 --> 1,2,1,2,3,4,5,6,7,8
                codebook_indexes[:, cur_distil_delta:, :] = codebook_indexes.clone()[:, :T - cur_distil_delta, :]
                unvalid_teacher_mask[:, :cur_distil_delta] = True
                codebook_indexes.masked_fill_(unvalid_teacher_mask, -100)
                # --> -100, -100, 1,2,3,4,5,6,-100,-100
            codebook_loss = self.codebook_loss_net(
                middle_layer_output, codebook_indexes
            )
        else:
            # when codebook indexes are not available
            codebook_loss = None
        assert torch.all(x_lens > 0)

        # Now for the decoder, i.e., the prediction network
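The distil_delta block above shifts the teacher labels so that the student at frame t + delta predicts the teacher's codebooks from frame t. A runnable toy example (one utterance, one codebook, delta assumed to be 2) that reproduces the sequence sketched in the comments:

import torch

delta = 2
# (N, T, num_codebooks): 8 valid teacher frames followed by 2 padded frames.
codebook_indexes = torch.tensor(
    [[1, 2, 3, 4, 5, 6, 7, 8, -100, -100]]
).unsqueeze(-1)
T = codebook_indexes.shape[1]

unvalid_teacher_mask = codebook_indexes == -100
# Shift the teacher labels right by `delta` frames ...
codebook_indexes[:, delta:, :] = codebook_indexes.clone()[:, : T - delta, :]
# ... then invalidate the first `delta` frames plus the original padding.
unvalid_teacher_mask[:, :delta] = True
codebook_indexes.masked_fill_(unvalid_teacher_mask, -100)

print(codebook_indexes.squeeze(-1))
# tensor([[-100, -100,    1,    2,    3,    4,    5,    6, -100, -100]])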
@@ -204,4 +248,32 @@ class Transducer(nn.Module):
            reduction=reduction,
        )

        return (simple_loss, pruned_loss)
        return (simple_loss, pruned_loss, codebook_loss)

    @staticmethod
    def concat_successive_codebook_indexes(
        middle_layer_output, codebook_indexes
    ):
        # The output rate of hubert is 50 frames per second,
        # while that of the current encoder is 25.
        # The following code handles two issues:
        # 1.
        # Roughly speaking, to generate one more output frame,
        # hubert needs two extra input frames,
        # while the current encoder needs four extra frames.
        # Suppose only three extra frames are provided:
        # hubert generates one more frame while the current encoder does nothing.
        # 2.
        # The codebook loss is a frame-wise loss. To let the 25 fps student output
        # learn from the 50 fps teacher output, every two successive frames of the
        # teacher output are concatenated together.
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape
        assert T >= t_expected, (T, t_expected)
        # Handling issue 1.
        if T >= t_expected * 2:
            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
        # Handling issue 2.
        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
        return codebook_indexes
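Since the hubert teacher runs at 50 frames per second and the subsampled Emformer at roughly 25, the helper above merges each pair of successive teacher frames into a single student frame by doubling the codebook dimension. A small shape-level sketch with assumed sizes:

import torch

N, T_teacher, C = 1, 8, 4   # 8 teacher frames, 4 codebooks (toy sizes)
t_expected = 4              # frames produced by the student's middle layer

codebook_indexes = torch.arange(N * T_teacher * C).reshape(N, T_teacher, C)
# Issue 1: keep exactly 2 * t_expected teacher frames.
codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
# Issue 2: merge every two successive teacher frames into one student frame.
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)

print(codebook_indexes.shape)  # torch.Size([1, 4, 8])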