concat codebook index and mid layer of student model

Guo Liyong 2022-01-06 00:59:17 +08:00
parent 6bb949eb5a
commit 4281121aa9
4 changed files with 88 additions and 21 deletions

View File

@ -18,7 +18,7 @@
import math
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import Tensor, nn
@ -172,7 +172,8 @@ class Conformer(Transformer):
)
def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
self, x: Tensor, supervisions: Optional[Supervisions] = None,
mid_layer_list: List[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
@ -197,11 +198,22 @@ class Conformer(Transformer):
mask = encoder_padding_mask(x.size(0), supervisions)
if mask is not None:
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
if mid_layer_list is not None:
x, mid_mem_embeddings = self.encoder(x, pos_emb, src_key_padding_mask=mask, mid_layer_list=mid_layer_list) # (T, B, F)
assert len(mid_layer_list) == len(mid_mem_embeddings)
# Note: torch.cat([x], dim=-1) == x,
# so the following code also works when len(mid_mem_embeddings) == 1.
mid_mem_embeddings = torch.cat(mid_mem_embeddings, dim=-1)
else:
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
if self.normalize_before:
x = self.after_norm(x)
if mid_layer_list is not None:
return x, mask, mid_mem_embeddings
return x, mask
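
For reference, a minimal self-contained sketch of what the concatenation above does; the shapes and layer indices below are illustrative assumptions, not values from the commit:

import torch

T, B, F = 100, 4, 256                      # illustrative time, batch, feature sizes
mid_layer_list = [6, 8, 10]                # hypothetical layer indices
# stand-ins for the outputs collected from layers 6, 8 and 10, each (T, B, F)
mid_mem_embeddings = [torch.randn(T, B, F) for _ in mid_layer_list]

# concatenate along the feature dimension, as run_encoder does above
mid_mem_embeddings = torch.cat(mid_mem_embeddings, dim=-1)
assert mid_mem_embeddings.shape == (T, B, F * len(mid_layer_list))
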
@ -405,7 +417,7 @@ class ConformerEncoder(nn.TransformerEncoder):
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
)
if mod_idx in mid_layer_list:
if mid_layer_list is not None and mod_idx in mid_layer_list:
mid_mem_embeddings.append(output)
if self.norm is not None:

View File

@ -87,6 +87,13 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--concat",
type=str2bool,
default=True,
help="whether to concat codeindex.",
)
parser.add_argument(
"--predictor",
type=str,
@ -111,6 +118,23 @@ def get_parser():
""",
)
parser.add_argument(
"--warm-step",
type=int,
default=80000,
help="""warm up steps""",
)
parser.add_argument(
"--mid-layer-list",
type=str,
default=None,
help="""e.g. 6,8,10
""",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -249,7 +273,7 @@ def get_params() -> AttributeDict:
"use_double_scores": True,
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000,
# "warm_step": 80000,
"env_info": get_env_info(),
}
)
@ -379,7 +403,12 @@ def compute_loss(
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
if params.mid_layer_list is not None:
# e.g. params.mid_layer_list = "6,8,10"
mid_layer_list = [int(layer_idx) for layer_idx in params.mid_layer_list.split(",")]
nnet_output, encoder_memory, memory_mask, mid_mem_embeddings = model(feature, supervisions, mid_layer_list)
else:
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
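
As a rough illustration of how the new --mid-layer-list option reaches the model: only the string parsing is verbatim from the diff, the model call is shown schematically.

mid_layer_list_str = "6,8,10"  # value passed via --mid-layer-list
mid_layer_list = [int(layer_idx) for layer_idx in mid_layer_list_str.split(",")]
assert mid_layer_list == [6, 8, 10]
# The model forward then returns an extra tensor of concatenated
# mid-layer embeddings, roughly:
#   nnet_output, encoder_memory, memory_mask, mid_mem_embeddings = model(
#       feature, supervisions, mid_layer_list
#   )
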
@ -449,18 +478,33 @@ def compute_loss(
if "wav2vec" == params.model_id:
# The frame rate of the wav2vec codebook indices is 50 Hz,
# while the conformer's is 25 Hz.
t_expected = encoder_memory.shape[0] * 2
assert codebook_indices.shape[1] >= t_expected
codebook_indices = codebook_indices[:, 0:t_expected:2, :]
encoder_memory = encoder_memory.transpose(0, 1) # T, N, C --> N, T, C
codebook_indices = codebook_indices.to(encoder_memory.device).long()
if not params.concat:
t_expected = encoder_memory.shape[0] * 2
assert codebook_indices.shape[1] >= t_expected
codebook_indices = codebook_indices[:, 0:t_expected:2, :]
else:
t_expected = encoder_memory.shape[0]
N, T, C = codebook_indices.shape
T = T // 2 * 2
codebook_indices = codebook_indices[:, :T, :]
codebook_indices = codebook_indices.reshape(N, T // 2, C * 2)
assert codebook_indices.shape[1] >= t_expected
codebook_indices = codebook_indices[:, 0:t_expected, :]
if params.mid_layer_list is not None:
mem_embeddings_codebook = mid_mem_embeddings.transpose(0, 1) # T, N, C --> N, T, C
else:
mem_embeddings_codebook = encoder_memory.transpose(0, 1) # T, N, C --> N, T, C
codebook_indices = codebook_indices.to(mem_embeddings_codebook.device).long()
if (
params.predictor == "ckpnt_predictor"
or params.predictor == "powerful"
):
codebook_loss = mmodel.cdidxnet(encoder_memory, codebook_indices)
codebook_loss = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices)
else:
total_logprob, _ = mmodel.cdidxnet(encoder_memory, codebook_indices)
total_logprob, _ = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices)
codebook_loss = -total_logprob
loss += params.codebook_weight * codebook_loss
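
A minimal sketch of the two alignment strategies above, using toy shapes (N batch items, T50 wav2vec frames at 50 Hz, C codebooks per frame). With --concat, each 25 Hz conformer frame is paired with the codebook indices of two adjacent 50 Hz frames, which is why num_codebooks is doubled when the model is built below:

import torch

N, T50, C = 2, 10, 4                       # toy sizes: batch, 50 Hz frames, codebooks/frame
codebook_indices = torch.randint(0, 256, (N, T50, C))

# --concat false: keep every second 50 Hz frame, still C codebooks per frame
subsampled = codebook_indices[:, 0:T50:2, :]                  # (N, T50 // 2, C)

# --concat true: pack each pair of adjacent 50 Hz frames into one 25 Hz frame
T_even = T50 // 2 * 2
paired = codebook_indices[:, :T_even, :].reshape(N, T_even // 2, C * 2)

assert subsampled.shape == (N, T50 // 2, C)
assert paired.shape == (N, T50 // 2, C * 2)
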
@ -673,7 +717,7 @@ def run(rank, world_size, args):
vgg_frontend=False,
use_feat_batchnorm=params.use_feat_batchnorm,
use_codebook_loss=True if params.codebook_weight > 0.0 else False,
num_codebooks=params.bytes_per_frame,
num_codebooks=params.bytes_per_frame * 2 if params.concat else params.bytes_per_frame,
predictor=params.predictor,
)
@ -792,9 +836,11 @@ def main():
if 0.0 != args.codebook_weight:
assert -1 == args.time_warp_factor
assert not args.exp_dir.endswith("/")
args.exp_dir = Path(
args.exp_dir = \
f"{args.exp_dir}-time_warp_factor{args.time_warp_factor}-bytes_per_frame{args.bytes_per_frame}-cdweight{args.codebook_weight}-predictor{args.predictor}-maxduration{args.max_duration}" # noqa: E501
)
if args.mid_layer_list is not None:
args.exp_dir = args.exp_dir + f"-mid_layer_list{args.mid_layer_list}"
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
world_size = args.world_size

View File

@ -158,7 +158,8 @@ class Transformer(nn.Module):
self.decoder_criterion = None
def forward(
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
self, x: torch.Tensor, supervision: Optional[Supervisions] = None,
mid_layer_list: List[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
@ -183,10 +184,18 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
if mid_layer_list is not None:
encoder_memory, memory_key_padding_mask, mid_mem_embeddings = self.run_encoder(
x, supervision, mid_layer_list,
)
else:
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision,
)
x = self.ctc_output(encoder_memory)
if mid_layer_list is not None:
return x, encoder_memory, memory_key_padding_mask, mid_mem_embeddings
return x, encoder_memory, memory_key_padding_mask
def run_encoder(

View File

@ -75,7 +75,7 @@ class LibriSpeechAsrDataModule(DataModule):
)
parser.add_argument(
"--subset",
type=Path,
type=str,
default=None,
help="which subset to extract codebook index"
"clean-100, clean-360, other-500",