Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

Commit 4281121aa9 (parent 6bb949eb5a): concat codebook index and mid layer of student model
@@ -18,7 +18,7 @@
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import Tensor, nn
@@ -172,7 +172,8 @@ class Conformer(Transformer):
         )

     def run_encoder(
-        self, x: Tensor, supervisions: Optional[Supervisions] = None
+        self, x: Tensor, supervisions: Optional[Supervisions] = None,
+        mid_layer_list: List[int] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
         Args:
@@ -197,11 +198,22 @@ class Conformer(Transformer):
         mask = encoder_padding_mask(x.size(0), supervisions)
         if mask is not None:
             mask = mask.to(x.device)
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
+        if mid_layer_list is not None:
+            x, mid_mem_embeddings = self.encoder(x, pos_emb, src_key_padding_mask=mask, mid_layer_list=mid_layer_list)  # (T, B, F)
+            assert len(mid_layer_list) == len(mid_mem_embeddings)
+            # note: tensor_a == torch.cat([tensor_a]),
+            # so the following code also works when len(mid_mem_embeddings) == 1
+            mid_mem_embeddings = torch.cat(mid_mem_embeddings, dim=-1)
+        else:
+            x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)

         if self.normalize_before:
             x = self.after_norm(x)

+        if mid_layer_list is not None:
+            return x, mask, mid_mem_embeddings
+
         return x, mask
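For readers skimming the diff, a minimal, self-contained sketch of what the new run_encoder branch does (the layer stack, sizes, and names below are illustrative, not taken from icefall): the outputs of the requested intermediate layers are collected and concatenated along the feature dimension, and since torch.cat of a single-element list returns an equal tensor, the same path covers one or several selected layers.

import torch
from torch import nn

# Toy stand-in for the encoder stack: 4 generic layers (assumption, not icefall code).
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])

def run_stack(x, mid_layer_list):
    """Run x through the stack and also return the outputs of the
    requested intermediate layers, concatenated on the feature axis."""
    mid_mem_embeddings = []
    out = x
    for idx, layer in enumerate(layers):
        out = layer(out)
        if idx in mid_layer_list:
            mid_mem_embeddings.append(out)
    assert len(mid_mem_embeddings) == len(mid_layer_list)
    # (T, B, F) tensors concatenated on dim=-1 -> (T, B, F * len(mid_layer_list))
    return out, torch.cat(mid_mem_embeddings, dim=-1)

x = torch.randn(10, 2, 8)      # (T, B, F)
_, mid = run_stack(x, [1, 3])
print(mid.shape)               # torch.Size([10, 2, 16])

# torch.cat of a single-element list returns an equal tensor,
# so the same code path also covers len(mid_layer_list) == 1.
_, single = run_stack(x, [3])
print(single.shape)            # torch.Size([10, 2, 8])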
@@ -405,7 +417,7 @@ class ConformerEncoder(nn.TransformerEncoder):
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
-            if mod_idx in mid_layer_list:
+            if mid_layer_list is not None and mod_idx in mid_layer_list:
                 mid_mem_embeddings.append(output)

         if self.norm is not None:
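The added `mid_layer_list is not None` guard matters because a membership test against None raises a TypeError; a tiny illustration (variable names are illustrative):

mid_layer_list = None
mod_idx = 3

# Without the guard this raises: TypeError: argument of type 'NoneType' is not iterable
# if mod_idx in mid_layer_list: ...

# With the short-circuiting guard, the membership test only runs
# when a list of layer indexes was actually passed.
if mid_layer_list is not None and mod_idx in mid_layer_list:
    print("collect this layer's output")
else:
    print("skip")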
@@ -87,6 +87,13 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )

+    parser.add_argument(
+        "--concat",
+        type=str2bool,
+        default=True,
+        help="whether to concatenate codebook indexes.",
+    )
+
     parser.add_argument(
         "--predictor",
         type=str,
@@ -111,6 +118,23 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--warm-step",
+        type=int,
+        default=80000,
+        help="""warm up steps""",
+    )
+
+    parser.add_argument(
+        "--mid-layer-list",
+        type=str,
+        default=None,
+        help="""e.g. 6,8,10
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
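A small standalone sketch of how these new options can be consumed (a local stand-in for icefall's str2bool is defined so the snippet runs on its own; the list parsing mirrors the compute_loss hunk further down):

import argparse

def str2bool(v: str) -> bool:
    # Local stand-in for icefall.utils.str2bool so the sketch is self-contained.
    return v.lower() in ("yes", "true", "t", "y", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--concat", type=str2bool, default=True,
                    help="whether to concatenate codebook indexes")
parser.add_argument("--mid-layer-list", type=str, default=None,
                    help="comma-separated layer indexes, e.g. 6,8,10")

args = parser.parse_args(["--concat", "true", "--mid-layer-list", "6,8,10"])

# Split the string into integer layer indexes, as compute_loss does below.
mid_layer_list = (
    [int(layer_idx) for layer_idx in args.mid_layer_list.split(",")]
    if args.mid_layer_list is not None
    else None
)
print(args.concat, mid_layer_list)   # True [6, 8, 10]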
@@ -249,7 +273,7 @@ def get_params() -> AttributeDict:
             "use_double_scores": True,
             # parameters for Noam
             "weight_decay": 1e-6,
-            "warm_step": 80000,
+            # "warm_step": 80000,
             "env_info": get_env_info(),
         }
     )
@@ -379,7 +403,12 @@ def compute_loss(

     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
-        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        if params.mid_layer_list is not None:
+            # example: params.mid_layer_list = "6,8,10"
+            mid_layer_list = [int(layer_idx) for layer_idx in params.mid_layer_list.split(",")]
+            nnet_output, encoder_memory, memory_mask, mid_mem_embeddings = model(feature, supervisions, mid_layer_list)
+        else:
+            nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
         # nnet_output is (N, T, C)

     # NOTE: We need `encode_supervisions` to sort sequences with
@@ -449,18 +478,33 @@ def compute_loss(
         if "wav2vec" == params.model_id:
             # frame rate of wav2vec codebooks_indices is 50
             # while for conformer is 25
-            t_expected = encoder_memory.shape[0] * 2
-            assert codebook_indices.shape[1] >= t_expected
-            codebook_indices = codebook_indices[:, 0:t_expected:2, :]
-        encoder_memory = encoder_memory.transpose(0, 1)  # T, N, C --> N, T, C
-        codebook_indices = codebook_indices.to(encoder_memory.device).long()
+            if not params.concat:
+                t_expected = encoder_memory.shape[0] * 2
+                assert codebook_indices.shape[1] >= t_expected
+                codebook_indices = codebook_indices[:, 0:t_expected:2, :]
+            else:
+                t_expected = encoder_memory.shape[0]
+                # import pdb; pdb.set_trace()
+                N, T, C = codebook_indices.shape
+
+                T = T // 2 * 2
+                codebook_indices = codebook_indices[:, :T, :]
+
+                codebook_indices = codebook_indices.reshape(N, T // 2, C * 2)
+                assert codebook_indices.shape[1] >= t_expected
+                codebook_indices = codebook_indices[:, 0:t_expected, :]
+        if params.mid_layer_list is not None:
+            mem_embeddings_codebook = mid_mem_embeddings.transpose(0, 1)  # T, N, C --> N, T, C
+        else:
+            mem_embeddings_codebook = encoder_memory.transpose(0, 1)  # T, N, C --> N, T, C
+        codebook_indices = codebook_indices.to(mem_embeddings_codebook.device).long()
         if (
             params.predictor == "ckpnt_predictor"
             or params.predictor == "powerful"
         ):
-            codebook_loss = mmodel.cdidxnet(encoder_memory, codebook_indices)
+            codebook_loss = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices)
         else:
-            total_logprob, _ = mmodel.cdidxnet(encoder_memory, codebook_indices)
+            total_logprob, _ = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices)
             codebook_loss = -total_logprob

         loss += params.codebook_weight * codebook_loss
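A shape-only sketch of the two alignment strategies in this hunk, under the assumption stated in the comments above that the wav2vec codebook indices run at twice the conformer frame rate (all numbers below are made up): without --concat every second teacher frame is dropped, while with --concat the indices of two adjacent teacher frames are folded into one student frame, so nothing is discarded and the per-frame codebook count doubles (which is why num_codebooks is doubled in the run() hunk below).

import torch

N, T, C = 2, 101, 8            # batch, teacher frames (50 Hz), codebooks per frame
t_expected = 50                # student frames at half the frame rate (25 Hz)
codebook_indices = torch.randint(0, 256, (N, T, C))

# Option 1 (not concat): keep every second teacher frame, discard the rest.
sub = codebook_indices[:, 0 : t_expected * 2 : 2, :]
print(sub.shape)               # torch.Size([2, 50, 8])

# Option 2 (concat): pair adjacent teacher frames, keeping all indices.
T_even = T // 2 * 2            # drop a possible trailing odd frame
paired = codebook_indices[:, :T_even, :].reshape(N, T_even // 2, C * 2)
paired = paired[:, :t_expected, :]
print(paired.shape)            # torch.Size([2, 50, 16])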
@@ -673,7 +717,7 @@ def run(rank, world_size, args):
         vgg_frontend=False,
         use_feat_batchnorm=params.use_feat_batchnorm,
         use_codebook_loss=True if params.codebook_weight > 0.0 else False,
-        num_codebooks=params.bytes_per_frame,
+        num_codebooks=params.bytes_per_frame * 2 if params.concat else params.bytes_per_frame,
         predictor=params.predictor,
     )
@@ -792,9 +836,11 @@ def main():
     if 0.0 != args.codebook_weight:
         assert -1 == args.time_warp_factor
         assert not args.exp_dir.endswith("/")
-        args.exp_dir = Path(
+        args.exp_dir = \
             f"{args.exp_dir}-time_warp_factor{args.time_warp_factor}-bytes_per_frame{args.bytes_per_frame}-cdweight{args.codebook_weight}-predictor{args.predictor}-maxduration{args.max_duration}"  # noqa: E501
-        )
+        if args.mid_layer_list is not None:
+            args.exp_dir = args.exp_dir + f"-mid_layer_list{args.mid_layer_list}"
+        args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)

     world_size = args.world_size
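For orientation, a sketch of the experiment-directory name this produces under some example argument values (the values are invented, not recipe defaults); note that the Path() conversion is deferred until after the optional mid_layer_list suffix has been appended:

from pathlib import Path

# Example values only; not taken from the recipe defaults.
exp_dir = "exp"
time_warp_factor, bytes_per_frame = -1, 4
codebook_weight, predictor, max_duration = 0.1, "ckpnt_predictor", 200
mid_layer_list = "6,8,10"

exp_dir = (
    f"{exp_dir}-time_warp_factor{time_warp_factor}-bytes_per_frame{bytes_per_frame}"
    f"-cdweight{codebook_weight}-predictor{predictor}-maxduration{max_duration}"
)
if mid_layer_list is not None:
    exp_dir = exp_dir + f"-mid_layer_list{mid_layer_list}"
print(Path(exp_dir))
# -> exp-time_warp_factor-1-...-maxduration200-mid_layer_list6,8,10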
@@ -158,7 +158,8 @@ class Transformer(nn.Module):
         self.decoder_criterion = None

     def forward(
-        self, x: torch.Tensor, supervision: Optional[Supervisions] = None
+        self, x: torch.Tensor, supervision: Optional[Supervisions] = None,
+        mid_layer_list: List[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Args:
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.feat_batchnorm(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
||||
x, supervision
|
||||
)
|
||||
if mid_layer_list is not None:
|
||||
encoder_memory, memory_key_padding_mask, mid_mem_embeddings = self.run_encoder(
|
||||
x, supervision, mid_layer_list,
|
||||
)
|
||||
else:
|
||||
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
||||
x, supervision,
|
||||
)
|
||||
x = self.ctc_output(encoder_memory)
|
||||
|
||||
if mid_layer_list is not None:
|
||||
return x, encoder_memory, memory_key_padding_mask, mid_mem_embeddings
|
||||
return x, encoder_memory, memory_key_padding_mask
|
||||
|
||||
def run_encoder(
|
||||
|
@@ -75,7 +75,7 @@ class LibriSpeechAsrDataModule(DataModule):
         )
         parser.add_argument(
             "--subset",
-            type=Path,
+            type=str,
             default=None,
             help="which subset to extract codebook index"
             "clean-100, clean-360, other-500",