diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
index 3cedf99b6..7f34bb5c9 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
@@ -1133,7 +1133,10 @@ class EmformerEncoder(nn.Module):
       tanh_on_mem (bool, optional):
        If ``true``, applies tanh to memory elements. (default: ``false``)
      negative_inf (float, optional):
        Value to use for negative infinity in attention weights. (default: -1e8)
+      output_layers:
+        A list of integers containing the indexes of the Emformer layers
+        whose activations will be returned.
     """
 
     def __init__(
@@ -1151,6 +1154,7 @@ class EmformerEncoder(nn.Module):
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        output_layers: List[int] = None,
     ):
         super().__init__()
 
@@ -1188,6 +1192,7 @@ class EmformerEncoder(nn.Module):
         self.chunk_length = chunk_length
         self.memory_size = memory_size
         self.cnn_module_kernel = cnn_module_kernel
+        self.output_layers = output_layers
 
     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
         """Hard copy each chunk's right context and concat them."""
@@ -1366,7 +1371,8 @@ class EmformerEncoder(nn.Module):
         padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
 
         output = utterance
-        for layer in self.emformer_layers:
+        layer_results = []
+        for layer_index, layer in enumerate(self.emformer_layers):
             output, right_context = layer(
                 output,
                 right_context,
@@ -1374,8 +1380,11 @@
                 padding_mask=padding_mask,
                 warmup=warmup,
             )
+            if layer_index in self.output_layers:
+                # (T, N, C) --> (N, T, C)
+                layer_results.append(output.permute(1, 0, 2))
 
-        return output, output_lengths
+        return layer_results, output_lengths
 
     @torch.jit.export
     def infer(
@@ -1545,6 +1554,7 @@ class Emformer(EncoderInterface):
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        middle_output_layer: int = None,  # 0-based layer index
     ):
         super().__init__()
 
@@ -1573,6 +1583,17 @@ class Emformer(EncoderInterface):
         # (2) embedding: num_features -> d_model
         self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
+        output_layers = []
+        if middle_output_layer is not None:
+            assert (
+                middle_output_layer >= 0
+                and middle_output_layer < num_encoder_layers
+            ), f"Invalid middle output layer: {middle_output_layer}"
+            output_layers.append(middle_output_layer)
+
+        # The last layer is always needed.
+        output_layers.append(num_encoder_layers - 1)
+
         self.encoder = EmformerEncoder(
             chunk_length=chunk_length // subsampling_factor,
             d_model=d_model,
@@ -1587,7 +1608,8 @@
             memory_size=memory_size,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
-        )
+            output_layers=output_layers,  # for distillation
+        )
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -1624,9 +1646,7 @@
         x_lens = (((x_lens - 1) >> 1) - 1) >> 1
         assert x.size(0) == x_lens.max().item()
 
-        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (T, N, C)
-
-        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (N, T, C)
 
         return output, output_lengths
 
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
index 8462ae92a..537d5deca 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
@@ -74,7 +74,8 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from emformer import Emformer
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
@@ -357,6 +358,41 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=8,
+        help="On which encoder layer to perform knowledge distillation (KD).",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks used for the codebook loss.",
+    )
+
+    parser.add_argument(
+        "--distil-delta",
+        type=int,
+        default=None,
+        help="Offset (in frames) between teacher and student when doing KD.",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of the codebook loss.",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -446,6 +482,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         left_context_length=params.left_context_length,
         right_context_length=params.right_context_length,
         memory_size=params.memory_size,
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder
 
@@ -483,6 +522,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        distil_delta=params.distil_delta if params.enable_distillation else 0,
     )
     return model
 
@@ -605,6 +646,16 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)
 
+def extract_codebook_indexes(batch):
+    cuts = batch["supervisions"]["cut"]
+    # -100 is identical to ignore_index in CE loss computation.
+    cuts_pre_mixed = [
+        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+    ]
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens
 
 def compute_loss(
     params: AttributeDict,
@@ -645,8 +696,14 @@
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
 
+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -654,6 +711,7 @@
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indexes=codebook_indexes,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -664,6 +722,10 @@
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
+        if is_training and params.enable_distillation:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss
+
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
@@ -684,6 +746,8 @@
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
 
     return loss, info
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index 272d06c37..cd8fd0223 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -1,4 +1,5 @@
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+#           2022 Xiaomi Corp. (authors: Zengwei Yao, Liyong Guo, Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -40,6 +41,8 @@ class Transducer(nn.Module):
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
+        num_codebooks: int = 0,
+        distil_delta: int = None,
     ):
         """
         Args:
@@ -68,6 +71,16 @@ class Transducer(nn.Module):
         self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
+
+        from multi_quantization.prediction import JointCodebookLoss
+        self.distil_delta = distil_delta
+
+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=encoder_dim,
+                num_codebooks=num_codebooks,
+                is_joint=False,
+            )
 
     def forward(
         self,
@@ -80,6 +93,7 @@
         warmup: float = 1.0,
         reduction: str = "sum",
         delay_penalty: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -112,6 +126,8 @@
             streaming models to emit symbols earlier. See
             https://github.com/k2-fsa/k2/issues/955 and
             https://arxiv.org/pdf/2211.00490.pdf for more details.
+          codebook_indexes:
+            Codebook indexes extracted from a teacher model.
         Returns:
           Return the transducer loss.
@@ -129,7 +145,35 @@
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        encoder_out = layer_results[-1]  # the last item is the final output
+
+        middle_layer_output = layer_results[0]
+        if self.training and codebook_indexes is not None:
+            assert hasattr(self, "codebook_loss_net")
+            # The HuBERT teacher and the Emformer student use different subsampling ratios.
+            if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
+                codebook_indexes = self.concat_successive_codebook_indexes(
+                    middle_layer_output, codebook_indexes
+                )
+            if self.distil_delta is not None:
+                N = codebook_indexes.shape[0]
+                T = codebook_indexes.shape[1]
+                cur_distil_delta = self.distil_delta
+                # Align the teacher with (student + self.distil_delta).
+                # Suppose self.distil_delta == 2:
+                unvalid_teacher_mask = codebook_indexes == -100
+                # 1,2,3,4,5,6,7,8,-100,-100 --> 1,2,1,2,3,4,5,6,7,8
+                codebook_indexes[:, cur_distil_delta:, :] = codebook_indexes.clone()[:, : T - cur_distil_delta, :]
+                unvalid_teacher_mask[:, :cur_distil_delta] = True
+                codebook_indexes.masked_fill_(unvalid_teacher_mask, -100)
+                # --> -100,-100,1,2,3,4,5,6,-100,-100
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        else:
+            # Codebook indexes are not available; no codebook loss.
+            codebook_loss = None
 
         assert torch.all(x_lens > 0)
 
         # Now for the decoder, i.e., the prediction network
@@ -204,4 +248,32 @@
             reduction=reduction,
         )
 
-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, codebook_loss)
+
+    @staticmethod
+    def concat_successive_codebook_indexes(
+        middle_layer_output, codebook_indexes
+    ):
+        # The output rate of the HuBERT teacher is 50 frames per second,
+        # while that of the current encoder is 25 frames per second.
+        # The following code handles two issues:
+        # 1.
+        #    Roughly speaking, to generate one extra output frame,
+        #    HuBERT needs two extra input frames,
+        #    while the current encoder needs four extra input frames.
+        #    If only three extra frames are provided,
+        #    HuBERT generates one more frame while the current encoder does nothing.
+        # 2.
+        #    The codebook loss is a frame-wise loss. To let the 25 Hz student output
+        #    learn from the 50 Hz teacher output, two successive frames of the
+        #    teacher output are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handling issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        # Handling issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
\ No newline at end of file
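
Illustration (not part of the patch): a minimal, self-contained sketch of the two alignment steps applied to the teacher codebook indexes above, i.e. `concat_successive_codebook_indexes` and the `distil_delta` shift in `Transducer.forward`. The toy sizes (N=1, T=10, C=2, t_expected=4, delta=2) are made up for the example.

    import torch

    # Toy teacher output: N=1 utterance, T=10 HuBERT frames (50 Hz), C=2 codebooks;
    # the last two frames are padding (-100, the ignore value used in the patch).
    N, T, C = 1, 10, 2
    codebook_indexes = torch.arange(1, T + 1).view(1, T, 1).repeat(N, 1, C)
    codebook_indexes[:, 8:, :] = -100

    # concat_successive_codebook_indexes: the student (25 Hz) produced
    # t_expected = 4 frames, so pair two successive teacher frames per student frame.
    t_expected = 4
    if codebook_indexes.shape[1] >= t_expected * 2:
        codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
    codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
    # -> shape (1, 4, 4): [[1,1,2,2], [3,3,4,4], [5,5,6,6], [7,7,8,8]]

    # distil_delta shift with delta == 2: student frame t is supervised by the
    # teacher content of frame t - delta; frames shifted in from the left are masked.
    delta = 2
    unvalid_teacher_mask = codebook_indexes == -100
    codebook_indexes[:, delta:, :] = codebook_indexes.clone()[:, : t_expected - delta, :]
    unvalid_teacher_mask[:, :delta] = True
    codebook_indexes.masked_fill_(unvalid_teacher_mask, -100)
    # -> [[-100,-100,-100,-100], [-100,-100,-100,-100], [1,1,2,2], [3,3,4,4]]

With delta = 2, the first two student frames (plus any frames that were already padding) are ignored via -100, and every later student frame learns the teacher content from two frames earlier.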