support MVQ as a training option

This commit is contained in:
marcoyang1998 2023-07-28 14:40:34 +08:00
parent 19b942c958
commit 3705a58624
3 changed files with 196 additions and 18 deletions

View File

@ -16,12 +16,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Optional, Tuple, List
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from multi_quantization.prediction import JointCodebookLoss
from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
@ -39,12 +40,15 @@ class AsrModel(nn.Module):
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
num_codebooks: int = 8,
cb_input_dim: int = 384,
):
"""A joint CTC & Transducer ASR model.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
- Potentially with MVQ knowledge distillation (https://arxiv.org/abs/2211.00508)
Args:
encoder_embed:
@ -70,6 +74,10 @@ class AsrModel(nn.Module):
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
num_codebooks:
Greater than 0 if we want to do MVQ knowledge distillation.
cb_input_dim:
The input dimension to the codebook loss module.
"""
super().__init__()
@ -110,6 +118,12 @@ class AsrModel(nn.Module):
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
if num_codebooks > 0:
self.codebook_loss_net = JointCodebookLoss(
predictor_channels=cb_input_dim,
num_codebooks=num_codebooks,
)
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
@ -127,6 +141,8 @@ class AsrModel(nn.Module):
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
saved_embeddings:
The embeddings from the middle layers
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
@ -135,12 +151,12 @@ class AsrModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens
return encoder_out, encoder_out_lens, middle_out
def forward_ctc(
self,
@ -180,6 +196,7 @@ class AsrModel(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
codebook_indexes: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
@ -286,6 +303,7 @@ class AsrModel(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
codebook_indexes: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
@ -306,9 +324,12 @@ class AsrModel(nn.Module):
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
codebook_indexes:
The codebook indexes to be predicted. Only used when doing knowledge
distillation with MVQ
Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss)
Return the transducer losses and CTC loss, and potentially codebook loss
in form of (simple_loss, pruned_loss, ctc_loss, codebook_loss)
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
@ -323,7 +344,7 @@ class AsrModel(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
encoder_out, encoder_out_lens, middle_out = self.forward_encoder(x, x_lens)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
@ -354,5 +375,84 @@ class AsrModel(nn.Module):
)
else:
ctc_loss = torch.empty(0)
if self.training and hasattr(self, "codebook_loss_net"):
assert codebook_indexes is not None
codebook_loss = self.forward_codebook(
middle_out=middle_out,
codebook_indexes=codebook_indexes,
)
else:
codebook_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss
return simple_loss, pruned_loss, ctc_loss, codebook_loss
def forward_codebook(
self,
middle_out: List[torch.Tensor],
codebook_indexes: torch.Tensor,
) -> torch.Tensor:
"""Calculate the codebook loss for the model (knowledge distillation)
Args:
middle_out (List[torch.Tensor]):
The embeddings extracted from the middle layer of the zipformer encoder
codebook_indexes (torch.Tensor):
The encoded codebook indexes for knowledge distillation
Returns:
The codebook loss value
"""
middle_layer_output = middle_out[0] # currently only support using output of one layer, (N,T,C)
len_CI = codebook_indexes.size(1)
len_mid_layer = middle_layer_output.size(1)
ratio = round(len_CI/len_mid_layer)
if ratio == 1: # Having the same frame rate
assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
codebook_indexes = codebook_indexes[:, :len_mid_layer, :]
assert codebook_indexes.size(1) == middle_layer_output.size(1)
codebook_loss = self.codebook_loss_net(
middle_layer_output, codebook_indexes
)
elif ratio == 2:
codebook_indexes = self.concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
)
codebook_loss = self.codebook_loss_net(
middle_layer_output, codebook_indexes
)
return codebook_loss
@staticmethod
def concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
):
# Output rate of hubert is 50 frames per second,
# while that of current encoder is 25.
# Following code handling two issues:
# 1.
# Roughly speaking, to generate another frame output,
# hubert needes extra two frames,
# while current encoder needs extra four frames.
# Suppose there are only extra three frames provided,
# hubert will generate another frame while current encoder does nothing.
# 2.
# codebook loss is a frame-wise loss, to enalbe 25 frames studnet output
# learns from 50 frames teacher output, two successive frames of teacher model
# output is concatenated together.
t_expected = middle_layer_output.shape[1]
N, T, C = codebook_indexes.shape
assert T >= t_expected, (T, t_expected)
# Handling issue 1.
if T >= t_expected * 2:
codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
if T / t_expected < 1.1: # To be changed, dirty hack to jump out of this function
codebook_indexes = codebook_indexes[:, : t_expected, :]
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
return codebook_indexes
# Handling issue 2.
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
return codebook_indexes

View File

@ -68,7 +68,8 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.cut import Cut, MonoCut
from lhotse.dataset.collation import collate_custom_field
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
@ -402,6 +403,34 @@ def get_parser():
default=0.2,
help="Scale for CTC loss.",
)
parser.add_argument(
"--enable-distillation",
type=str2bool,
default=True,
help="Whether to eanble distillation.",
)
parser.add_argument(
"--codebook-loss-scale",
type=float,
default=0.1,
help="The scale of codebook loss.",
)
parser.add_argument(
"--num-codebooks",
type=int,
default=16,
help="Number of codebooks used for the extracted CI",
)
parser.add_argument(
"--distillation-layer",
type=int,
default=4,
help="Where to perform MVQ-KD",
)
parser.add_argument(
"--seed",
@ -579,6 +608,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
causal=params.causal,
chunk_size=_to_int_tuple(params.chunk_size),
left_context_frames=_to_int_tuple(params.left_context_frames),
middle_output_layer=params.distillation_layer
if params.enable_distillation
else None,
)
return encoder
@ -630,6 +662,8 @@ def get_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
use_transducer=params.use_transducer,
use_ctc=params.use_ctc,
num_codebooks=params.num_codebooks if params.enable_distillation else 0,
cb_input_dim=_to_int_tuple(params.encoder_dim)[params.distillation_layer],
)
return model
@ -749,6 +783,16 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
cuts = batch["supervisions"]["cut"]
# -100 is identical to ignore_value in CE loss computation.
cuts_pre_mixed = [
c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
]
codebook_indexes, codebook_indexes_lens = collate_custom_field(
cuts_pre_mixed, "codebook_indexes", pad_value=-100
)
return codebook_indexes, codebook_indexes_lens
def compute_loss(
params: AttributeDict,
@ -790,15 +834,22 @@ def compute_loss(
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
if is_training and params.enable_distillation:
codebook_indexes, _ = extract_codebook_indexes(batch)
codebook_indexes = codebook_indexes.to(device)
else:
codebook_indexes = None
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
simple_loss, pruned_loss, ctc_loss, codebook_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
codebook_indexes=codebook_indexes,
)
loss = 0.0
@ -822,6 +873,9 @@ def compute_loss(
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
if is_training and params.enable_distillation:
loss += params.codebook_loss_scale * codebook_loss
assert loss.requires_grad == is_training
@ -837,6 +891,8 @@ def compute_loss(
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if is_training and params.enable_distillation:
info["codebook_loss"] = codebook_loss.detach().cpu().item()
return loss, info
@ -1105,6 +1161,11 @@ def run(rank, world_size, args):
else:
tb_writer = None
# Note: it's better to set --spec-aug-time-warpi-factor=-1
# when doing distillation with vq.
if params.enable_distillation:
assert args.spec_aug_time_warp_factor < 1, "Specaug should be disabled during distillation"
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
@ -1234,14 +1295,14 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:

View File

@ -90,6 +90,8 @@ class Zipformer2(EncoderInterface):
context chunks for causal training; will be rounded to a number of
chunks. Must not be less than cnn_module_kernel (after factoring in
rounding and downsampling); an error will be thrown if this is violated.
middle_output_layer:
Get the output of a middle layer of the model
"""
def __init__(
self,
@ -110,6 +112,7 @@ class Zipformer2(EncoderInterface):
causal: bool = False,
chunk_size: Tuple[int] = [-1],
left_context_frames: Tuple[int] = [-1],
middle_output_layer: int = None # 0-based layer index
) -> None:
super(Zipformer2, self).__init__()
@ -190,6 +193,17 @@ class Zipformer2(EncoderInterface):
encoders.append(encoder)
self.encoders = nn.ModuleList(encoders)
# for mvq: return the middle layer output
output_layers = []
if middle_output_layer is not None:
assert (
middle_output_layer >= 0
and middle_output_layer < len(num_encoder_layers)
)
output_layers.append(middle_output_layer)
self.output_layers = output_layers # A list of int
self.downsample_output = SimpleDownsample(max(encoder_dim),
downsample=output_downsampling_factor,
@ -334,6 +348,9 @@ class Zipformer2(EncoderInterface):
x = self._get_full_dim_output(outputs)
x = self.downsample_output(x)
# class Downsample has this rounding behavior..
saved = [outputs[i].permute(1,0,2) for i in self.output_layers] # collect the embeddings
assert self.output_downsampling_factor == 2, self.output_downsampling_factor
if torch.jit.is_scripting() or torch.jit.is_tracing():
lengths = (x_lens + 1) // 2
@ -342,7 +359,7 @@ class Zipformer2(EncoderInterface):
warnings.simplefilter("ignore")
lengths = (x_lens + 1) // 2
return x, lengths
return x, lengths, saved
def _get_attn_mask(
self, x: Tensor,