From 3705a58624b921f82447e2a454cb142cc48bb6e8 Mon Sep 17 00:00:00 2001
From: marcoyang1998
Date: Fri, 28 Jul 2023 14:40:34 +0800
Subject: [PATCH] support MVQ as a training option

---
 egs/librispeech/ASR/zipformer/model.py     | 114 +++++++++++++++++++--
 egs/librispeech/ASR/zipformer/train.py     |  81 +++++++++++++--
 egs/librispeech/ASR/zipformer/zipformer.py |  19 +++-
 3 files changed, 196 insertions(+), 18 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index f2f86af47..64fde095b 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -16,12 +16,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple
+from typing import Optional, Tuple, List

 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from multi_quantization.prediction import JointCodebookLoss

 from icefall.utils import add_sos, make_pad_mask
 from scaling import ScaledLinear
@@ -39,12 +40,15 @@ class AsrModel(nn.Module):
         vocab_size: int = 500,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        num_codebooks: int = 8,
+        cb_input_dim: int = 384,
     ):
         """A joint CTC & Transducer ASR model.

        - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
        - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
        - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
+       - Optionally with MVQ knowledge distillation (https://arxiv.org/abs/2211.00508)

         Args:
           encoder_embed:
@@ -70,6 +74,10 @@ class AsrModel(nn.Module):
            Whether use transducer head. Default: True.
          use_ctc:
            Whether use CTC head. Default: False.
+         num_codebooks:
+           Greater than 0 if we want to do MVQ knowledge distillation.
+         cb_input_dim:
+           The input dimension to the codebook loss module.
         """
         super().__init__()

@@ -110,6 +118,12 @@ class AsrModel(nn.Module):
                 nn.Linear(encoder_dim, vocab_size),
                 nn.LogSoftmax(dim=-1),
             )
+
+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=cb_input_dim,
+                num_codebooks=num_codebooks,
+            )

     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
@@ -127,6 +141,8 @@ class AsrModel(nn.Module):
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
+         middle_out:
+           A list of embeddings saved from the middle layers of the encoder.
         """
         # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
         x, x_lens = self.encoder_embed(x, x_lens)
@@ -135,12 +151,12 @@ class AsrModel(nn.Module):
         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

-        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask)

         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

-        return encoder_out, encoder_out_lens
+        return encoder_out, encoder_out_lens, middle_out

     def forward_ctc(
         self,
@@ -180,6 +196,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -286,6 +303,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -306,9 +324,12 @@ class AsrModel(nn.Module):
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
+         codebook_indexes:
+           The codebook indexes to be predicted. Only used when doing
+           knowledge distillation with MVQ.
         Returns:
-          Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss)
+          Return the transducer losses, the CTC loss and, optionally, the codebook
+          loss, in the form of (simple_loss, pruned_loss, ctc_loss, codebook_loss)

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -323,7 +344,7 @@ class AsrModel(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)

         # Compute encoder outputs
-        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
+        encoder_out, encoder_out_lens, middle_out = self.forward_encoder(x, x_lens)

         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -354,5 +375,84 @@ class AsrModel(nn.Module):
             )
         else:
             ctc_loss = torch.empty(0)
+
+        if self.training and hasattr(self, "codebook_loss_net"):
+            assert codebook_indexes is not None
+            codebook_loss = self.forward_codebook(
+                middle_out=middle_out,
+                codebook_indexes=codebook_indexes,
+            )
+        else:
+            codebook_loss = torch.empty(0)

-        return simple_loss, pruned_loss, ctc_loss
+        return simple_loss, pruned_loss, ctc_loss, codebook_loss
+
+    def forward_codebook(
+        self,
+        middle_out: List[torch.Tensor],
+        codebook_indexes: torch.Tensor,
+    ) -> torch.Tensor:
+        """Calculate the codebook loss for the model (knowledge distillation).
+
+        Args:
+          middle_out (List[torch.Tensor]):
+            The embeddings extracted from the middle layer of the zipformer encoder.
+          codebook_indexes (torch.Tensor):
+            The encoded codebook indexes for knowledge distillation.
+
+        Returns:
+          The codebook loss value.
+        """
+        middle_layer_output = middle_out[0]  # currently only supports using the output of one layer, (N, T, C)
+        len_CI = codebook_indexes.size(1)
+        len_mid_layer = middle_layer_output.size(1)
+        ratio = round(len_CI / len_mid_layer)
+
+        if ratio == 1:  # the same frame rate
+            assert len_CI >= len_mid_layer, (len_CI, len_mid_layer)
+            codebook_indexes = codebook_indexes[:, :len_mid_layer, :]
+            assert codebook_indexes.size(1) == middle_layer_output.size(1)
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        elif ratio == 2:
+            codebook_indexes = self.concat_successive_codebook_indexes(
+                middle_layer_output, codebook_indexes
+            )
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+
+        return codebook_loss
+
+    @staticmethod
+    def concat_successive_codebook_indexes(
+        middle_layer_output, codebook_indexes
+    ):
+        # The output rate of hubert is 50 frames per second,
+        # while that of the current encoder is 25.
+        # The following code handles two issues:
+        # 1.
+        #   Roughly speaking, to generate one extra output frame,
+        #   hubert needs two extra input frames,
+        #   while the current encoder needs four extra input frames.
+        #   If only three extra frames are provided, hubert generates one
+        #   more frame while the current encoder generates nothing.
+        # 2.
+        #   The codebook loss is a frame-wise loss. To enable the 25 Hz student
+        #   output to learn from the 50 Hz teacher output, two successive frames
+        #   of the teacher output are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handling issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        if T / t_expected < 1.1:  # To be changed; a dirty hack to return early from this function
+            codebook_indexes = codebook_indexes[:, :t_expected, :]
+            assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+            return codebook_indexes
+        # Handling issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index bc3e9c1ba..f5142650b 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -68,7 +68,8 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -402,6 +403,34 @@ def get_parser():
         default=0.2,
         help="Scale for CTC loss.",
     )
+
+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable MVQ knowledge distillation.",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of the codebook loss.",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks used for the extracted codebook indexes.",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=4,
+        help="The 0-based index of the encoder layer at which to perform MVQ-KD.",
+    )

     parser.add_argument(
         "--seed",
@@ -579,6 +608,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder
@@ -630,6 +662,8 @@ def get_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        cb_input_dim=_to_int_tuple(params.encoder_dim)[params.distillation_layer],
     )
     return model
@@ -749,6 +783,16 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
+    cuts = batch["supervisions"]["cut"]
+    # -100 is identical to the ignore_value in the CE loss computation.
+    cuts_pre_mixed = [
+        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+    ]
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens

 def compute_loss(
     params: AttributeDict,
@@ -790,15 +834,22 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
+
+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        simple_loss, pruned_loss, ctc_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            codebook_indexes=codebook_indexes,
         )

         loss = 0.0
@@ -822,6 +873,9 @@ def compute_loss(
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
+
+        if is_training and params.enable_distillation:
+            loss += params.codebook_loss_scale * codebook_loss

     assert loss.requires_grad == is_training
@@ -837,6 +891,8 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info
@@ -1105,6 +1161,11 @@ def run(rank, world_size, args):
     else:
         tb_writer = None

+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    if params.enable_distillation:
+        assert args.spec_aug_time_warp_factor < 1, "Time warping in SpecAug should be disabled during distillation"
+
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
@@ -1234,14 +1295,14 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index b39af02b8..7897057ff 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -90,6 +90,8 @@ class Zipformer2(EncoderInterface):
            context chunks for causal training; will be rounded to a number of chunks.
            Must not be less than cnn_module_kernel (after factoring in rounding and
            downsampling); an error will be thrown if this is violated.
+      middle_output_layer:
+        If not None, also return the output of this (0-based) middle layer, used for MVQ distillation.
     """
     def __init__(
         self,
@@ -110,6 +112,7 @@ class Zipformer2(EncoderInterface):
         causal: bool = False,
         chunk_size: Tuple[int] = [-1],
         left_context_frames: Tuple[int] = [-1],
+        middle_output_layer: Optional[int] = None,  # 0-based layer index
     ) -> None:
         super(Zipformer2, self).__init__()
@@ -190,6 +193,17 @@ class Zipformer2(EncoderInterface):
             encoders.append(encoder)

         self.encoders = nn.ModuleList(encoders)
+
+        # For MVQ distillation: record which middle layer outputs to return.
+        output_layers = []
+        if middle_output_layer is not None:
+            assert (
+                middle_output_layer >= 0
+                and middle_output_layer < len(num_encoder_layers)
+            )
+            output_layers.append(middle_output_layer)
+
+        self.output_layers = output_layers  # a list of int

         self.downsample_output = SimpleDownsample(max(encoder_dim),
                                                   downsample=output_downsampling_factor,
@@ -334,6 +348,9 @@ class Zipformer2(EncoderInterface):

         x = self._get_full_dim_output(outputs)
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
+
+        saved = [outputs[i].permute(1, 0, 2) for i in self.output_layers]  # collect the middle-layer embeddings, (N, T, C)
+
         assert self.output_downsampling_factor == 2, self.output_downsampling_factor
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             lengths = (x_lens + 1) // 2
@@ -342,7 +359,7 @@ class Zipformer2(EncoderInterface):
             warnings.simplefilter("ignore")
             lengths = (x_lens + 1) // 2

-        return x, lengths
+        return x, lengths, saved

     def _get_attn_mask(
         self, x: Tensor,
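
For reference, a minimal standalone sketch (illustrative only, with made-up tensor
sizes) of the frame-rate alignment performed by concat_successive_codebook_indexes
above, assuming a 50 Hz teacher (e.g. HuBERT) and a 25 Hz student middle layer:
successive pairs of teacher frames are folded into the channel dimension so that
each student frame predicts 2 * C codebook indexes.

    import torch

    N, T_teacher, C = 2, 100, 16   # hypothetical: batch, teacher frames, codebooks per frame
    t_expected = 50                # number of student (middle-layer) frames

    codebook_indexes = torch.randint(0, 256, (N, T_teacher, C))

    # Issue 1: drop teacher frames beyond 2 * t_expected.
    codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
    # Issue 2: pair up successive teacher frames -> (N, t_expected, 2 * C).
    codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
    print(codebook_indexes.shape)  # torch.Size([2, 50, 32])

With this patch, distillation is controlled from train.py via --enable-distillation,
--distillation-layer, --num-codebooks and --codebook-loss-scale; the assertion added
in run() additionally expects --spec-aug-time-warp-factor to be set below 1
(e.g. -1) when distillation is enabled.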