diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
index d022d463e..062854c78 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
@@ -96,12 +96,15 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    DecodingResults,
+    parse_hyp_and_timestamp,
     setup_logger,
-    store_transcripts,
+    store_transcripts_and_timestamps,
     str2bool,
-    write_error_stats,
+    write_error_stats_with_timestamps,
 )
 
 LOG_EPS = math.log(1e-10)
@@ -165,6 +168,13 @@ def get_parser():
         help="Path to the BPE model",
     )
 
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -237,7 +247,7 @@ def decode_one_batch(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -284,10 +294,12 @@
     )
 
     encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    if isinstance(encoder_out, list):
+        encoder_out = encoder_out[-1]  # the last item is final output
 
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search_one_best(
+        res = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -295,63 +307,72 @@
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
-        hyp_tokens = greedy_search_batch(
+        res = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            return_timestamps=True,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
+        res = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            return_timestamps=True,
        )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
-
+        tokens = []
+        timestamps = []
         for i in range(batch_size):
             # fmt: off
             encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
             # fmt: on
             if params.decoding_method == "greedy_search":
-                hyp = greedy_search(
+                res = greedy_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
+                    return_timestamps=True,
                 )
             elif params.decoding_method == "beam_search":
-                hyp = beam_search(
+                res = beam_search(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    return_timestamps=True,
                 )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            tokens.extend(res.tokens)
+            timestamps.extend(res.timestamps)
+        res = DecodingResults(hyps=tokens, timestamps=timestamps)
+
+    hyps, timestamps = parse_hyp_and_timestamp(
+        res=res,
+        sp=sp,
+        subsampling_factor=params.subsampling_factor,
+        frame_shift_ms=params.frame_shift_ms,
+    )
 
     if params.decoding_method == "greedy_search":
"greedy_search": - return {"greedy_search": hyps} + return {"greedy_search": (hyps, timestamps)} elif params.decoding_method == "fast_beam_search": return { ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}" - ): hyps + ): (hyps, timestamps) } else: - return {f"beam_size_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": (hyps, timestamps)} def decode_dataset( @@ -360,7 +381,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: +) ->Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -378,9 +399,12 @@ def decode_dataset( Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. + Its value is a list of tuples. Each tuple contains five elements: + - cut_id + - reference transcript + - predicted result + - timestamp of reference transcript + - timestamp of predicted result """ num_cuts = 0 @@ -390,14 +414,26 @@ def decode_dataset( num_batches = "?" if params.decoding_method == "greedy_search": - log_interval = 100 + log_interval = 50 else: - log_interval = 2 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + timestamps_ref = [] + for cut in batch["supervisions"]["cut"]: + for s in cut.supervisions: + time = [] + if s.alignment is not None and "word" in s.alignment: + time = [ + aliword.start + for aliword in s.alignment["word"] + if aliword.symbol != "" + ] + timestamps_ref.append(time) hyps_dict = decode_one_batch( params=params, @@ -407,12 +443,16 @@ def decode_dataset( batch=batch, ) - for name, hyps in hyps_dict.items(): + for name, (hyps, timestamps_hyp) in hyps_dict.items(): this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + assert len(hyps) == len(texts) and len(timestamps_hyp) == len( + timestamps_ref + ) + for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip( + cut_ids, hyps, texts, timestamps_hyp, timestamps_ref + ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -428,23 +468,28 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[ + str, + List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + ], ): test_set_wers = dict() + test_set_delays = dict() for key, results in results_dict.items(): recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(
+            wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
             )
             test_set_wers[key] = wer
+            test_set_delays[key] = (mean_delay, var_delay)
 
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
@@ -455,6 +500,19 @@
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)
 
+    test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
+    delays_info = (
+        params.res_dir
+        / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
+    )
+    with open(delays_info, "w") as f:
+        print("settings\tsymbol-delay", file=f)
+        for key, val in test_set_delays:
+            print(
+                "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
+                file=f,
+            )
+
     s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:
@@ -462,6 +520,13 @@
         note = ""
     logging.info(s)
 
+    s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_delays:
+        s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
+        note = ""
+    logging.info(s)
+
 
 @torch.no_grad()
 def main():
@@ -511,7 +576,7 @@
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <unk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
@@ -580,9 +645,9 @@
                 )
             )
     else:
-        assert params.avg > 0
+        assert params.avg > 0, params.avg
         start = params.epoch - params.avg
-        assert start >= 1
+        assert start >= 1, start
         filename_start = f"{params.exp_dir}/epoch-{start}.pt"
         filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
         logging.info(
@@ -606,6 +671,9 @@
     else:
         decoding_graph = None
 
+    lexicon = Lexicon(params.lang_dir)
+    word_table = lexicon.word_table
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
index d9ef5a2da..cc593ca6b 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
@@ -1133,7 +1133,10 @@ class EmformerEncoder(nn.Module):
       tanh_on_mem (bool, optional):
         If ``true``, applies tanh to memory elements. (default: ``false``)
       negative_inf (float, optional):
         Value to use for negative infinity in attention weights. (default: -1e8)
+      output_layers (List[int], optional):
+        A list of integers containing the indexes of the emformer layers whose
+        activations will be returned. (default: ``None``)
     """
 
     def __init__(
@@ -1151,6 +1154,7 @@
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        output_layers: Optional[List[int]] = None,
     ):
         super().__init__()
 
@@ -1188,6 +1192,7 @@
         self.chunk_length = chunk_length
         self.memory_size = memory_size
         self.cnn_module_kernel = cnn_module_kernel
+        self.output_layers = output_layers
 
     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
         """Hard copy each chunk's right context and concat them."""
@@ -1361,7 +1366,8 @@
         padding_mask = make_pad_mask(attention_mask.shape[1] - U + output_lengths)
 
         output = utterance
-        for layer in self.emformer_layers:
+        layer_results = []
+        for layer_index, layer in enumerate(self.emformer_layers):
             output, right_context = layer(
                 output,
                 right_context,
@@ -1369,8 +1375,11 @@
                 padding_mask=padding_mask,
                 warmup=warmup,
             )
+            if layer_index in self.output_layers:
+                # (T, N, C) --> (N, T, C)
+                layer_results.append(output.permute(1, 0, 2))
 
-        return output, output_lengths
+        return layer_results, output_lengths
 
     @torch.jit.export
     def infer(
@@ -1540,6 +1549,7 @@
         memory_size: int = 0,
         tanh_on_mem: bool = False,
         negative_inf: float = -1e8,
+        middle_output_layer: Optional[int] = None,  # 0-based layer index
     ):
         super().__init__()
 
@@ -1568,6 +1578,17 @@
         # (2) embedding: num_features -> d_model
         self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
+        output_layers = []
+        if middle_output_layer is not None:
+            assert (
+                middle_output_layer >= 0
+                and middle_output_layer < num_encoder_layers
+            ), f"Invalid middle output layer: {middle_output_layer}"
+            output_layers.append(middle_output_layer)
+
+        # The last layer is always needed.
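+        # Its output is the final encoder output consumed by the transducer
+        # losses; the middle layer output (if any) is used for the codebook
+        # (distillation) loss.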
+        output_layers.append(num_encoder_layers - 1)
+
         self.encoder = EmformerEncoder(
             chunk_length=chunk_length // subsampling_factor,
             d_model=d_model,
@@ -1582,7 +1603,8 @@
             memory_size=memory_size,
             tanh_on_mem=tanh_on_mem,
             negative_inf=negative_inf,
-        )
+            output_layers=output_layers,  # for distillation
+        )
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -1619,9 +1641,7 @@
         x_lens = (((x_lens - 1) >> 1) - 1) >> 1
         assert x.size(0) == x_lens.max().item()
 
-        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (T, N, C)
-
-        output = output.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        output, output_lengths = self.encoder(x, x_lens, warmup=warmup)  # (N, T, C)
 
         return output, output_lengths
 
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
index 23ddb6bec..f7d456ea4 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
@@ -74,7 +74,8 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from emformer import Emformer
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
@@ -165,6 +166,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of entries in the memory for the Emformer",
     )
 
+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=8,
+        help="On which encoder layer to perform KD.",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks.",
+    )
+
+    # Distillation related args
+    parser.add_argument(
+        "--distil-delta",
+        type=int,
+        default=None,
+        help="Offset when doing KD.",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of codebook loss.",
+    )
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -408,6 +444,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -446,6 +483,9 @@
         left_context_length=params.left_context_length,
         right_context_length=params.right_context_length,
         memory_size=params.memory_size,
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder
 
@@ -483,6 +523,8 @@
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        distil_delta=params.distil_delta if params.enable_distillation else 0,
     )
     return model
 
@@ -602,6 +644,19 @@
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)
 
+
+def extract_codebook_indexes(batch):
+    cuts = batch["supervisions"]["cut"]
+    # -100 is identical to ignore_value in CE loss computation.
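+    # For a MixedCut (e.g. a noise-augmented cut), the codebook indexes live on
+    # the original first-track MonoCut, so unwrap it before collation.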
+    cuts_pre_mixed = [
+        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+    ]
+    # Sanity check: every (pre-mixed) cut must carry codebook indexes.
+    for cut in cuts_pre_mixed:
+        assert hasattr(cut, "codebook_indexes"), cut.id
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens
+
 
 def compute_loss(
     params: AttributeDict,
@@ -642,8 +697,14 @@
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
 
+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -651,6 +712,7 @@
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indexes=codebook_indexes,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -661,6 +723,10 @@
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
+        if is_training and params.enable_distillation:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss
+
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
@@ -681,6 +747,8 @@
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
 
     return loss, info
 
@@ -894,6 +962,11 @@ def run(rank, world_size, args):
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
 
+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with VQ.
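+    # Time warping would break the frame-level alignment between the
+    # pre-computed teacher codebook indexes and the student encoder output.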
+    if params.enable_distillation:
+        assert (
+            args.spec_aug_time_warp_factor < 1
+        ), "You need to disable time warp in MVQ KD"
+
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None
 
@@ -959,10 +1032,10 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)
 
+    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
-    else:
-        train_cuts = librispeech.train_clean_100_cuts()
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -992,14 +1065,14 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index 272d06c37..cd8fd0223 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -1,4 +1,5 @@
 # Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang, Wei Kang)
+#            2022  Xiaomi Corp.  (authors: Zengwei Yao, Liyong Guo, Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -40,6 +41,8 @@ class Transducer(nn.Module):
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
+        num_codebooks: int = 0,
+        distil_delta: int = None,
     ):
         """
         Args:
@@ -68,6 +71,16 @@
         self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
 
+        self.distil_delta = distil_delta
+
+        if num_codebooks > 0:
+            # Import here so that multi_quantization is required only when
+            # distillation is actually enabled.
+            from multi_quantization.prediction import JointCodebookLoss
+
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=encoder_dim,
+                num_codebooks=num_codebooks,
+                is_joint=False,
+            )
 
     def forward(
         self,
@@ -80,6 +93,7 @@
         warmup: float = 1.0,
         reduction: str = "sum",
         delay_penalty: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -112,6 +126,8 @@
             streaming models to emit symbols earlier. See
             https://github.com/k2-fsa/k2/issues/955 and
             https://arxiv.org/pdf/2211.00490.pdf for more details.
+          codebook_indexes:
+            Codebook indexes extracted from a teacher model.
         Returns:
           Return the transducer loss.
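+          Note: during training, if codebook_indexes is provided (MVQ
+          distillation), the codebook loss is returned as the third element
+          of the output tuple; otherwise the third element is None.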
@@ -129,7 +145,35 @@
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        encoder_out = layer_results[-1]  # the last item is the final output
+
+        middle_layer_output = layer_results[0]
+        if self.training and codebook_indexes is not None:
+            assert hasattr(self, "codebook_loss_net")
+            # The HuBERT teacher and the Emformer student have different
+            # subsampling ratios, so the teacher frames may need merging.
+            if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
+                codebook_indexes = self.concat_successive_codebook_indexes(
+                    middle_layer_output, codebook_indexes
+                )
+            if self.distil_delta is not None:
+                T = codebook_indexes.shape[1]
+                cur_distil_delta = self.distil_delta
+                # Align the teacher frames with the student frames shifted by
+                # self.distil_delta, e.g. with distil_delta == 2:
+                invalid_teacher_mask = codebook_indexes == -100
+                # 1,2,3,4,5,6,7,8,-100,-100 --> 1,2,1,2,3,4,5,6,7,8
+                codebook_indexes[:, cur_distil_delta:, :] = codebook_indexes.clone()[
+                    :, : T - cur_distil_delta, :
+                ]
+                invalid_teacher_mask[:, :cur_distil_delta] = True
+                codebook_indexes.masked_fill_(invalid_teacher_mask, -100)
+                # --> -100,-100,1,2,3,4,5,6,-100,-100
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        else:
+            # when codebook indexes are not available.
+            codebook_loss = None
 
         assert torch.all(x_lens > 0)
 
         # Now for the decoder, i.e., the prediction network
@@ -204,4 +248,32 @@
             reduction=reduction,
         )
 
-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, codebook_loss)
+
+    @staticmethod
+    def concat_successive_codebook_indexes(
+        middle_layer_output, codebook_indexes
+    ):
+        # The output rate of the HuBERT teacher is 50 frames per second,
+        # while that of the current encoder is 25 frames per second.
+        # The following code handles two issues:
+        # 1.
+        # Roughly speaking, to produce one extra output frame, HuBERT needs
+        # two extra input frames, while the current encoder needs four.
+        # If only three extra frames are provided, HuBERT produces one more
+        # frame while the current encoder does not.
+        # 2.
+        # The codebook loss is a frame-wise loss. To let the 25 frame/s
+        # student output learn from the 50 frame/s teacher output, every two
+        # successive frames of the teacher output are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handling issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        # Handling issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
\ No newline at end of file