Concat codebook indices and mid-layer embeddings of the student model

Guo Liyong 2022-01-06 00:59:17 +08:00
parent 6bb949eb5a
commit 4281121aa9
4 changed files with 88 additions and 21 deletions
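
At a high level, this commit teaches the Conformer encoder to also return the outputs of selected intermediate ("mid") layers, concatenated along the feature dimension, so the codebook-index predictor of the student model can be trained on those embeddings rather than only on the final encoder output. Below is a minimal sketch of that collect-and-concat step, using illustrative names (TinyEncoder, d_model) rather than the repo's exact classes:

from typing import List, Optional, Tuple

import torch
from torch import Tensor, nn


class TinyEncoder(nn.Module):
    """Stand-in for a stack of conformer layers (illustrative only)."""

    def __init__(self, num_layers: int = 12, d_model: int = 256):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(num_layers)]
        )

    def forward(
        self, x: Tensor, mid_layer_list: Optional[List[int]] = None
    ) -> Tuple[Tensor, Optional[Tensor]]:
        mid_mem_embeddings = []
        for mod_idx, layer in enumerate(self.layers):
            x = layer(x)  # (T, B, F)
            if mid_layer_list is not None and mod_idx in mid_layer_list:
                mid_mem_embeddings.append(x)
        if mid_layer_list is not None:
            # Concatenate along the feature dim: k selected layers give
            # k * d_model features per frame for the codebook predictor.
            return x, torch.cat(mid_mem_embeddings, dim=-1)
        return x, None

For example, mid_layer_list=[6, 8, 10] with d_model=256 yields mid-layer embeddings of 768 features per frame; in the diff below, those embeddings (rather than encoder_memory) are what the cdidxnet predictor consumes.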

View File

@@ -18,7 +18,7 @@
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import Tensor, nn
@@ -172,7 +172,8 @@ class Conformer(Transformer):
     )

     def run_encoder(
-        self, x: Tensor, supervisions: Optional[Supervisions] = None
+        self, x: Tensor, supervisions: Optional[Supervisions] = None,
+        mid_layer_list: List[int] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
         Args:
@@ -197,11 +198,22 @@ class Conformer(Transformer):
         mask = encoder_padding_mask(x.size(0), supervisions)
         if mask is not None:
             mask = mask.to(x.device)
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
+        if mid_layer_list is not None:
+            x, mid_mem_embeddings = self.encoder(x, pos_emb, src_key_padding_mask=mask, mid_layer_list=mid_layer_list)  # (T, B, F)
+            assert len(mid_layer_list) == len(mid_mem_embeddings)
+            # note: torch.cat([t], dim=-1) equals t, so the cat below also
+            # works when len(mid_mem_embeddings) == 1
+            mid_mem_embeddings = torch.cat(mid_mem_embeddings, dim=-1)
+        else:
+            x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)

         if self.normalize_before:
             x = self.after_norm(x)

+        if mid_layer_list is not None:
+            return x, mask, mid_mem_embeddings
         return x, mask
@@ -405,7 +417,7 @@ class ConformerEncoder(nn.TransformerEncoder):
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
-            if mod_idx in mid_layer_list:
+            if mid_layer_list is not None and mod_idx in mid_layer_list:
                 mid_mem_embeddings.append(output)

         if self.norm is not None:

View File

@@ -87,6 +87,13 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )

+    parser.add_argument(
+        "--concat",
+        type=str2bool,
+        default=True,
+        help="Whether to concat codebook indices.",
+    )
+
     parser.add_argument(
         "--predictor",
         type=str,
@@ -111,6 +118,23 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--warm-step",
+        type=int,
+        default=80000,
+        help="""warm up steps""",
+    )
+
+    parser.add_argument(
+        "--mid-layer-list",
+        type=str,
+        default=None,
+        help="""e.g. 6,8,10
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -249,7 +273,7 @@ def get_params() -> AttributeDict:
             "use_double_scores": True,
             # parameters for Noam
             "weight_decay": 1e-6,
-            "warm_step": 80000,
+            # "warm_step": 80000,
             "env_info": get_env_info(),
         }
     )
@@ -379,6 +403,11 @@ def compute_loss(
     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
-        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        if params.mid_layer_list is not None:
+            # e.g. params.mid_layer_list == "6,8,10"
+            mid_layer_list = [int(layer_idx) for layer_idx in params.mid_layer_list.split(",")]
+            nnet_output, encoder_memory, memory_mask, mid_mem_embeddings = model(feature, supervisions, mid_layer_list)
+        else:
+            nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
         # nnet_output is (N, T, C)
@@ -449,18 +478,33 @@ def compute_loss(
         if "wav2vec" == params.model_id:
             # frame rate of wav2vec codebooks_indices is 50
             # while for conformer is 25
-            t_expected = encoder_memory.shape[0] * 2
-            assert codebook_indices.shape[1] >= t_expected
-            codebook_indices = codebook_indices[:, 0:t_expected:2, :]
-        encoder_memory = encoder_memory.transpose(0, 1)  # T, N, C --> N, T, C
-        codebook_indices = codebook_indices.to(encoder_memory.device).long()
+            if not params.concat:
+                t_expected = encoder_memory.shape[0] * 2
+                assert codebook_indices.shape[1] >= t_expected
+                codebook_indices = codebook_indices[:, 0:t_expected:2, :]
+            else:
+                t_expected = encoder_memory.shape[0]
+                # import pdb; pdb.set_trace()
+                N, T, C = codebook_indices.shape
+                T = T // 2 * 2
+                codebook_indices = codebook_indices[:, :T, :]
+                codebook_indices = codebook_indices.reshape(N, T // 2, C * 2)
+                assert codebook_indices.shape[1] >= t_expected
+                codebook_indices = codebook_indices[:, 0:t_expected, :]
+
+        if params.mid_layer_list is not None:
+            mem_embeddings_codebook = mid_mem_embeddings.transpose(0, 1)  # T, N, C --> N, T, C
+        else:
+            mem_embeddings_codebook = encoder_memory.transpose(0, 1)  # T, N, C --> N, T, C
+        codebook_indices = codebook_indices.to(mem_embeddings_codebook.device).long()

         if (
             params.predictor == "ckpnt_predictor"
             or params.predictor == "powerful"
         ):
-            codebook_loss = mmodel.cdidxnet(encoder_memory, codebook_indices)
+            codebook_loss = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices)
         else:
-            total_logprob, _ = mmodel.cdidxnet(encoder_memory, codebook_indices)
+            total_logprob, _ = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices)
             codebook_loss = -total_logprob

         loss += params.codebook_weight * codebook_loss
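
The --concat branch in the hunk above is the key shape change: wav2vec codebook indices arrive at 50 Hz with shape (N, T, C), and instead of keeping every other frame they are packed in pairs to match the conformer's 25 Hz frame rate, giving (N, T // 2, 2 * C); this is also why num_codebooks is doubled when the model is built later in this file. A small worked example of that reshape (shapes are illustrative):

import torch

N, T, C = 2, 101, 4  # C plays the role of bytes_per_frame
codebook_indices = torch.randint(0, 256, (N, T, C))

T_even = T // 2 * 2  # drop a trailing odd frame, if any
packed = codebook_indices[:, :T_even, :].reshape(N, T_even // 2, C * 2)
print(packed.shape)  # torch.Size([2, 50, 8]), i.e. 2 * bytes_per_frame codebooks per frame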
@@ -673,7 +717,7 @@ def run(rank, world_size, args):
         vgg_frontend=False,
         use_feat_batchnorm=params.use_feat_batchnorm,
         use_codebook_loss=True if params.codebook_weight > 0.0 else False,
-        num_codebooks=params.bytes_per_frame,
+        num_codebooks=params.bytes_per_frame * 2 if params.concat else params.bytes_per_frame,
         predictor=params.predictor,
     )
@@ -792,9 +836,11 @@ def main():
     if 0.0 != args.codebook_weight:
         assert -1 == args.time_warp_factor
         assert not args.exp_dir.endswith("/")
-        args.exp_dir = Path(
-            f"{args.exp_dir}-time_warp_factor{args.time_warp_factor}-bytes_per_frame{args.bytes_per_frame}-cdweight{args.codebook_weight}-predictor{args.predictor}-maxduration{args.max_duration}"  # noqa: E501
-        )
+        args.exp_dir = \
+            f"{args.exp_dir}-time_warp_factor{args.time_warp_factor}-bytes_per_frame{args.bytes_per_frame}-cdweight{args.codebook_weight}-predictor{args.predictor}-maxduration{args.max_duration}"  # noqa: E501
+        if args.mid_layer_list is not None:
+            args.exp_dir = args.exp_dir + f"-mid_layer_list{args.mid_layer_list}"
+        args.exp_dir = Path(args.exp_dir)

     args.lang_dir = Path(args.lang_dir)
     world_size = args.world_size

View File

@@ -158,7 +158,8 @@ class Transformer(nn.Module):
         self.decoder_criterion = None

     def forward(
-        self, x: torch.Tensor, supervision: Optional[Supervisions] = None
+        self, x: torch.Tensor, supervision: Optional[Supervisions] = None,
+        mid_layer_list: List[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Args:
@@ -183,10 +184,18 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        if mid_layer_list is not None:
+            encoder_memory, memory_key_padding_mask, mid_mem_embeddings = self.run_encoder(
+                x, supervision, mid_layer_list,
+            )
+        else:
+            encoder_memory, memory_key_padding_mask = self.run_encoder(
+                x, supervision,
+            )
         x = self.ctc_output(encoder_memory)
+        if mid_layer_list is not None:
+            return x, encoder_memory, memory_key_padding_mask, mid_mem_embeddings
         return x, encoder_memory, memory_key_padding_mask

     def run_encoder(

View File

@@ -75,7 +75,7 @@ class LibriSpeechAsrDataModule(DataModule):
         )
         parser.add_argument(
             "--subset",
-            type=Path,
+            type=str,
             default=None,
             help="which subset to extract codebook index"
             "clean-100, clean-360, other-500",