diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index 7e2d7fadd..bd286642e 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -18,7 +18,7 @@
 
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -172,7 +172,8 @@ class Conformer(Transformer):
         )
 
     def run_encoder(
-        self, x: Tensor, supervisions: Optional[Supervisions] = None
+        self, x: Tensor, supervisions: Optional[Supervisions] = None,
+        mid_layer_list: Optional[List[int]] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
         Args:
@@ -197,11 +198,22 @@ class Conformer(Transformer):
         mask = encoder_padding_mask(x.size(0), supervisions)
         if mask is not None:
             mask = mask.to(x.device)
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
+        if mid_layer_list is not None:
+            x, mid_mem_embeddings = self.encoder(x, pos_emb, src_key_padding_mask=mask, mid_layer_list=mid_layer_list)  # (T, B, F)
+            assert len(mid_layer_list) == len(mid_mem_embeddings)
+            # torch.cat([t], dim=-1) equals t, so the concatenation below
+            # also works when len(mid_mem_embeddings) == 1.
+            mid_mem_embeddings = torch.cat(mid_mem_embeddings, dim=-1)
+        else:
+            x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
+
         if self.normalize_before:
             x = self.after_norm(x)
 
+        if mid_layer_list is not None:
+            return x, mask, mid_mem_embeddings
+
         return x, mask
@@ -405,7 +417,7 @@ class ConformerEncoder(nn.TransformerEncoder):
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
-            if mod_idx in mid_layer_list:
+            if mid_layer_list is not None and mod_idx in mid_layer_list:
                 mid_mem_embeddings.append(output)
 
         if self.norm is not None:
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 2e7b50762..ff48882fc 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -87,6 +87,13 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--concat",
+        type=str2bool,
+        default=True,
+        help="Whether to concat codebook indexes of two adjacent frames.",
+    )
+
     parser.add_argument(
         "--predictor",
         type=str,
@@ -111,6 +118,23 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--warm-step",
+        type=int,
+        default=80000,
+        help="Number of warm-up steps for the Noam optimizer.",
+    )
+
+    parser.add_argument(
+        "--mid-layer-list",
+        type=str,
+        default=None,
+        help="""Comma-separated indexes of intermediate encoder layers
6,8,10 + """, + ) + + parser.add_argument( "--exp-dir", type=str, @@ -249,7 +273,7 @@ def get_params() -> AttributeDict: "use_double_scores": True, # parameters for Noam "weight_decay": 1e-6, - "warm_step": 80000, + # "warm_step": 80000, "env_info": get_env_info(), } ) @@ -379,7 +403,12 @@ def compute_loss( supervisions = batch["supervisions"] with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + if params.mid_layer_list is not None: + # examples of params.mid_layer_list = 6,8,10 + mid_layer_list = [int(layer_idx) for layer_idx in params.mid_layer_list.split(",")] + nnet_output, encoder_memory, memory_mask, mid_mem_embeddings = model(feature, supervisions, mid_layer_list) + else: + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) # nnet_output is (N, T, C) # NOTE: We need `encode_supervisions` to sort sequences with @@ -449,18 +478,33 @@ def compute_loss( if "wav2vec" == params.model_id: # frame rate of wav2vec codebooks_indices is 50 # while for conformer is 25 - t_expected = encoder_memory.shape[0] * 2 - assert codebook_indices.shape[1] >= t_expected - codebook_indices = codebook_indices[:, 0:t_expected:2, :] - encoder_memory = encoder_memory.transpose(0, 1) # T, N, C --> N, T, C - codebook_indices = codebook_indices.to(encoder_memory.device).long() + if not params.concat: + t_expected = encoder_memory.shape[0] * 2 + assert codebook_indices.shape[1] >= t_expected + codebook_indices = codebook_indices[:, 0:t_expected:2, :] + else: + t_expected = encoder_memory.shape[0] + # import pdb; pdb.set_trace() + N, T, C = codebook_indices.shape + + T = T // 2 * 2 + codebook_indices = codebook_indices[:, :T, :] + + codebook_indices = codebook_indices.reshape(N, T // 2, C * 2) + assert codebook_indices.shape[1] >= t_expected + codebook_indices = codebook_indices[:, 0:t_expected, :] + if params.mid_layer_list is not None: + mem_embeddings_codebook = mid_mem_embeddings.transpose(0, 1) # T, N, C --> N, T, C + else: + mem_embeddings_codebook = encoder_memory.transpose(0, 1) # T, N, C --> N, T, C + codebook_indices = codebook_indices.to(mem_embeddings_codebook.device).long() if ( params.predictor == "ckpnt_predictor" or params.predictor == "powerful" ): - codebook_loss = mmodel.cdidxnet(encoder_memory, codebook_indices) + codebook_loss = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices) else: - total_logprob, _ = mmodel.cdidxnet(encoder_memory, codebook_indices) + total_logprob, _ = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices) codebook_loss = -total_logprob loss += params.codebook_weight * codebook_loss @@ -673,7 +717,7 @@ def run(rank, world_size, args): vgg_frontend=False, use_feat_batchnorm=params.use_feat_batchnorm, use_codebook_loss=True if params.codebook_weight > 0.0 else False, - num_codebooks=params.bytes_per_frame, + num_codebooks=params.bytes_per_frame * 2 if params.concat else params.bytes_per_frame, predictor=params.predictor, ) @@ -792,9 +836,11 @@ def main(): if 0.0 != args.codebook_weight: assert -1 == args.time_warp_factor assert not args.exp_dir.endswith("/") - args.exp_dir = Path( + args.exp_dir = \ f"{args.exp_dir}-time_warp_factor{args.time_warp_factor}-bytes_per_frame{args.bytes_per_frame}-cdweight{args.codebook_weight}-predictor{args.predictor}-maxduration{args.max_duration}" # noqa: E501 - ) + if args.mid_layer_list is not None: + args.exp_dir = args.exp_dir + f"-mid_layer_list{args.mid_layer_list}" + args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) 
 
     world_size = args.world_size
diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index f93914aaa..215c1015f 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -158,7 +158,8 @@ class Transformer(nn.Module):
             self.decoder_criterion = None
 
     def forward(
-        self, x: torch.Tensor, supervision: Optional[Supervisions] = None
+        self, x: torch.Tensor, supervision: Optional[Supervisions] = None,
+        mid_layer_list: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Args:
@@ -183,10 +184,18 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        if mid_layer_list is not None:
+            encoder_memory, memory_key_padding_mask, mid_mem_embeddings = self.run_encoder(
+                x, supervision, mid_layer_list,
+            )
+        else:
+            encoder_memory, memory_key_padding_mask = self.run_encoder(
+                x, supervision,
+            )
         x = self.ctc_output(encoder_memory)
+
+        if mid_layer_list is not None:
+            return x, encoder_memory, memory_key_padding_mask, mid_mem_embeddings
         return x, encoder_memory, memory_key_padding_mask
 
     def run_encoder(
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index aabb2804c..7baf93f5d 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -75,7 +75,7 @@ class LibriSpeechAsrDataModule(DataModule):
         )
         parser.add_argument(
             "--subset",
-            type=Path,
+            type=str,
             default=None,
             help="which subset to extract codebook index"
             "clean-100, clean-360, other-500",
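
Note: the sketch below is not part of the patch. It is a self-contained illustration, with made-up shapes, of the two tensor manipulations the patch introduces: the `--concat` frame-rate matching of wav2vec codebook indexes, and the concatenation of intermediate-layer embeddings selected by `--mid-layer-list`.

# Standalone sketch; shapes are arbitrary and only for illustration.
import torch

# 1) `--concat`: merge every two adjacent 50 Hz wav2vec codebook frames into
#    one 25 Hz frame, doubling the number of codebooks per frame.
N, T, C = 2, 101, 4                      # batch, 50 Hz frames, codebooks per frame
codebook_indices = torch.randint(0, 256, (N, T, C))
t_expected = 50                          # number of 25 Hz conformer frames
T = T // 2 * 2                           # drop the last frame if T is odd
codebook_indices = codebook_indices[:, :T, :].reshape(N, T // 2, C * 2)
codebook_indices = codebook_indices[:, :t_expected, :]
print(codebook_indices.shape)            # torch.Size([2, 50, 8])

# 2) `--mid-layer-list`: concatenate the hidden states of the selected encoder
#    layers along the feature dimension before the codebook predictor.
mid_layer_list = [int(i) for i in "6,8,10".split(",")]
T_enc, B, F = 50, 2, 256                 # frames, batch, encoder dim
mid_mem_embeddings = [torch.randn(T_enc, B, F) for _ in mid_layer_list]
mid_mem_embeddings = torch.cat(mid_mem_embeddings, dim=-1)
print(mid_mem_embeddings.shape)          # torch.Size([50, 2, 768])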