mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 02:34:21 +00:00)

commit 4281121aa9
parent 6bb949eb5a

concat codebook index and mid layer of student model
@@ -18,7 +18,7 @@
 
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -172,7 +172,8 @@ class Conformer(Transformer):
         )
 
     def run_encoder(
-        self, x: Tensor, supervisions: Optional[Supervisions] = None
+        self, x: Tensor, supervisions: Optional[Supervisions] = None,
+        mid_layer_list: List[int] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
         Args:
@@ -197,11 +198,22 @@ class Conformer(Transformer):
         mask = encoder_padding_mask(x.size(0), supervisions)
         if mask is not None:
             mask = mask.to(x.device)
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
+        if mid_layer_list is not None:
+            x, mid_mem_embeddings = self.encoder(x, pos_emb, src_key_padding_mask=mask, mid_layer_list=mid_layer_list)  # (T, B, F)
+            assert len(mid_layer_list) == len(mid_mem_embeddings)
+            # note: x == torch.cat([x]),
+            # so the following code also works when len(mid_mem_embeddings) == 1
+            mid_mem_embeddings = torch.cat(mid_mem_embeddings, dim=-1)
+        else:
+            x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, B, F)
 
+
         if self.normalize_before:
             x = self.after_norm(x)
 
+        if mid_layer_list is not None:
+            return x, mask, mid_mem_embeddings
+
         return x, mask
 
 
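Editor's note: a minimal sketch (not part of the commit) of what the concatenation above does to tensor shapes; the dimensions below are made up.

    import torch

    T, N, C = 100, 8, 512        # frames, batch size, encoder dim (made-up values)
    mid_layer_list = [6, 8, 10]  # example value, as in the --mid-layer-list help below

    # stand-ins for the outputs collected after the selected encoder layers
    mid_mem_embeddings = [torch.randn(T, N, C) for _ in mid_layer_list]

    # concatenation along the feature dim: (T, N, C) x 3 -> (T, N, 3 * C)
    concatenated = torch.cat(mid_mem_embeddings, dim=-1)
    assert concatenated.shape == (T, N, C * len(mid_layer_list))

    # torch.cat([x], dim=-1) equals x, so the same path also works
    # when only one layer is selected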
@@ -405,7 +417,7 @@ class ConformerEncoder(nn.TransformerEncoder):
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
-            if mod_idx in mid_layer_list:
+            if mid_layer_list is not None and mod_idx in mid_layer_list:
                 mid_mem_embeddings.append(output)
 
         if self.norm is not None:
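Editor's note: a small self-contained sketch of the collection pattern guarded above, using plain Linear layers instead of the real conformer encoder layers (assumed here for illustration only).

    import torch
    from torch import nn

    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])  # stand-in encoder layers
    mid_layer_list = [6, 8, 10]

    output = torch.randn(4, 8)
    mid_mem_embeddings = []
    for mod_idx, mod in enumerate(layers):
        output = mod(output)
        # same guard as in ConformerEncoder.forward: only collect when requested
        if mid_layer_list is not None and mod_idx in mid_layer_list:
            mid_mem_embeddings.append(output)

    assert len(mid_mem_embeddings) == len(mid_layer_list)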
@@ -87,6 +87,13 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--concat",
+        type=str2bool,
+        default=True,
+        help="Whether to concat codebook indexes.",
+    )
+
     parser.add_argument(
         "--predictor",
         type=str,
@@ -111,6 +118,23 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--warm-step",
+        type=int,
+        default=80000,
+        help="""warm up steps""",
+    )
+
+
+    parser.add_argument(
+        "--mid-layer-list",
+        type=str,
+        default=None,
+        help="""e.g. 6,8,10
+        """,
+    )
+
+
     parser.add_argument(
         "--exp-dir",
         type=str,
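Editor's note: the new --mid-layer-list option is a plain comma-separated string; a sketch of how such a value is turned into layer indexes (mirroring the parsing added in compute_loss further down):

    mid_layer_list_str = "6,8,10"  # example value, as in the argument's help text
    mid_layer_list = [int(layer_idx) for layer_idx in mid_layer_list_str.split(",")]
    assert mid_layer_list == [6, 8, 10]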
@@ -249,7 +273,7 @@ def get_params() -> AttributeDict:
             "use_double_scores": True,
             # parameters for Noam
             "weight_decay": 1e-6,
-            "warm_step": 80000,
+            # "warm_step": 80000,
             "env_info": get_env_info(),
         }
     )
@@ -379,7 +403,12 @@ def compute_loss(
 
     supervisions = batch["supervisions"]
    with torch.set_grad_enabled(is_training):
-        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        if params.mid_layer_list is not None:
+            # example: params.mid_layer_list = "6,8,10"
+            mid_layer_list = [int(layer_idx) for layer_idx in params.mid_layer_list.split(",")]
+            nnet_output, encoder_memory, memory_mask, mid_mem_embeddings = model(feature, supervisions, mid_layer_list)
+        else:
+            nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
         # nnet_output is (N, T, C)
 
     # NOTE: We need `encode_supervisions` to sort sequences with
@@ -449,18 +478,33 @@ def compute_loss(
         if "wav2vec" == params.model_id:
             # frame rate of wav2vec codebooks_indices is 50
             # while for conformer is 25
-            t_expected = encoder_memory.shape[0] * 2
-            assert codebook_indices.shape[1] >= t_expected
-            codebook_indices = codebook_indices[:, 0:t_expected:2, :]
-        encoder_memory = encoder_memory.transpose(0, 1)  # T, N, C --> N, T, C
-        codebook_indices = codebook_indices.to(encoder_memory.device).long()
+            if not params.concat:
+                t_expected = encoder_memory.shape[0] * 2
+                assert codebook_indices.shape[1] >= t_expected
+                codebook_indices = codebook_indices[:, 0:t_expected:2, :]
+            else:
+                t_expected = encoder_memory.shape[0]
+                # import pdb; pdb.set_trace()
+                N, T, C = codebook_indices.shape
+
+                T = T // 2 * 2
+                codebook_indices = codebook_indices[:, :T, :]
+
+                codebook_indices = codebook_indices.reshape(N, T // 2, C * 2)
+                assert codebook_indices.shape[1] >= t_expected
+                codebook_indices = codebook_indices[:, 0:t_expected, :]
+        if params.mid_layer_list is not None:
+            mem_embeddings_codebook = mid_mem_embeddings.transpose(0, 1)  # T, N, C --> N, T, C
+        else:
+            mem_embeddings_codebook = encoder_memory.transpose(0, 1)  # T, N, C --> N, T, C
+        codebook_indices = codebook_indices.to(mem_embeddings_codebook.device).long()
         if (
             params.predictor == "ckpnt_predictor"
             or params.predictor == "powerful"
         ):
-            codebook_loss = mmodel.cdidxnet(encoder_memory, codebook_indices)
+            codebook_loss = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices)
         else:
-            total_logprob, _ = mmodel.cdidxnet(encoder_memory, codebook_indices)
+            total_logprob, _ = mmodel.cdidxnet(mem_embeddings_codebook, codebook_indices)
             codebook_loss = -total_logprob
 
         loss += params.codebook_weight * codebook_loss
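Editor's note: a toy sketch (made-up sizes) of the frame-pairing reshape in the concat branch above; wav2vec codebook indexes come at 50 Hz while conformer frames are at 25 Hz, so every two adjacent rows of indexes are merged into one.

    import torch

    N, T, C = 2, 11, 4  # batch, wav2vec frames, codebooks (bytes_per_frame) per frame
    codebook_indices = torch.arange(N * T * C).reshape(N, T, C)

    T = T // 2 * 2                                 # drop a possible odd trailing frame
    codebook_indices = codebook_indices[:, :T, :]

    # frames (2t, 2t + 1) end up side by side in one row with 2 * C indexes,
    # which is why num_codebooks becomes bytes_per_frame * 2 when --concat is on
    paired = codebook_indices.reshape(N, T // 2, C * 2)
    assert paired.shape == (N, T // 2, C * 2)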
@@ -673,7 +717,7 @@ def run(rank, world_size, args):
         vgg_frontend=False,
         use_feat_batchnorm=params.use_feat_batchnorm,
         use_codebook_loss=True if params.codebook_weight > 0.0 else False,
-        num_codebooks=params.bytes_per_frame,
+        num_codebooks=params.bytes_per_frame * 2 if params.concat else params.bytes_per_frame,
         predictor=params.predictor,
     )
 
@@ -792,9 +836,11 @@ def main():
     if 0.0 != args.codebook_weight:
         assert -1 == args.time_warp_factor
         assert not args.exp_dir.endswith("/")
-        args.exp_dir = Path(
+        args.exp_dir = \
             f"{args.exp_dir}-time_warp_factor{args.time_warp_factor}-bytes_per_frame{args.bytes_per_frame}-cdweight{args.codebook_weight}-predictor{args.predictor}-maxduration{args.max_duration}"  # noqa: E501
-        )
+        if args.mid_layer_list is not None:
+            args.exp_dir = args.exp_dir + f"-mid_layer_list{args.mid_layer_list}"
+        args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
 
     world_size = args.world_size
@@ -158,7 +158,8 @@ class Transformer(nn.Module):
             self.decoder_criterion = None
 
     def forward(
-        self, x: torch.Tensor, supervision: Optional[Supervisions] = None
+        self, x: torch.Tensor, supervision: Optional[Supervisions] = None,
+        mid_layer_list: List[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Args:
@@ -183,10 +184,18 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-        encoder_memory, memory_key_padding_mask = self.run_encoder(
-            x, supervision
-        )
+        if mid_layer_list is not None:
+            encoder_memory, memory_key_padding_mask, mid_mem_embeddings = self.run_encoder(
+                x, supervision, mid_layer_list,
+            )
+        else:
+            encoder_memory, memory_key_padding_mask = self.run_encoder(
+                x, supervision,
+            )
         x = self.ctc_output(encoder_memory)
+
+        if mid_layer_list is not None:
+            return x, encoder_memory, memory_key_padding_mask, mid_mem_embeddings
         return x, encoder_memory, memory_key_padding_mask
 
     def run_encoder(
@@ -75,7 +75,7 @@ class LibriSpeechAsrDataModule(DataModule):
         )
         parser.add_argument(
             "--subset",
-            type=Path,
+            type=str,
             default=None,
             help="which subset to extract codebook index"
             "clean-100, clean-360, other-500",