support MVQ as a training option

Author: marcoyang1998 (2023-07-28 14:40:34 +08:00)
Parent: 19b942c958
Commit: 3705a58624
3 changed files with 196 additions and 18 deletions

File 1 of 3: model.py (class AsrModel)

@@ -16,12 +16,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from multi_quantization.prediction import JointCodebookLoss
 from icefall.utils import add_sos, make_pad_mask
 from scaling import ScaledLinear
@@ -39,12 +40,15 @@ class AsrModel(nn.Module):
         vocab_size: int = 500,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        num_codebooks: int = 8,
+        cb_input_dim: int = 384,
     ):
         """A joint CTC & Transducer ASR model.
         - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
         - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
         - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
+        - Potentially with MVQ knowledge distillation (https://arxiv.org/abs/2211.00508)
         Args:
           encoder_embed:
@@ -70,6 +74,10 @@ class AsrModel(nn.Module):
             Whether use transducer head. Default: True.
           use_ctc:
             Whether use CTC head. Default: False.
+          num_codebooks:
+            Greater than 0 if we want to do MVQ knowledge distillation.
+          cb_input_dim:
+            The input dimension to the codebook loss module.
         """
         super().__init__()
@@ -111,6 +119,12 @@ class AsrModel(nn.Module):
                 nn.LogSoftmax(dim=-1),
             )
+
+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=cb_input_dim,
+                num_codebooks=num_codebooks,
+            )
 
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
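The constructor above attaches a JointCodebookLoss head when num_codebooks > 0. Below is a minimal, hedged sketch of how that head is exercised on its own; the call pattern (embeddings plus integer indexes in, a scalar-like loss out) follows how forward_codebook uses it later in this diff, while the codebook size of 256 is an assumption, not something stated in the commit.

# A minimal sketch of the codebook-loss head in isolation. It assumes the
# JointCodebookLoss signature shown in the diff (predictor_channels, num_codebooks),
# that calling the module with (embeddings, indexes) returns a scalar-like loss
# (as forward_codebook does below), and a codebook size of 256 (an assumption).
import torch
from multi_quantization.prediction import JointCodebookLoss

N, T, C = 2, 100, 384   # batch, frames, embedding dim (cb_input_dim)
num_codebooks = 16      # must match the codebooks used when extracting the indexes

codebook_loss_net = JointCodebookLoss(
    predictor_channels=C,
    num_codebooks=num_codebooks,
)

embeddings = torch.randn(N, T, C)                       # student middle-layer output
indexes = torch.randint(0, 256, (N, T, num_codebooks))  # teacher codebook indexes
loss = codebook_loss_net(embeddings, indexes)           # to be scaled and added to the total loss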
@@ -127,6 +141,8 @@ class AsrModel(nn.Module):
             Encoder output, of shape (N, T, C).
           encoder_out_lens:
             Encoder output lengths, of shape (N,).
+          middle_out:
+            The embeddings collected from the selected middle layer(s) of the encoder.
         """
         # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
         x, x_lens = self.encoder_embed(x, x_lens)
@@ -135,12 +151,12 @@ class AsrModel(nn.Module):
         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask)
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
-        return encoder_out, encoder_out_lens
+        return encoder_out, encoder_out_lens, middle_out
 
     def forward_ctc(
         self,
@@ -180,6 +196,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -286,6 +303,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -306,9 +324,12 @@ class AsrModel(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          codebook_indexes:
+            The codebook indexes to be predicted. Only used when doing knowledge
+            distillation with MVQ.
         Returns:
-          Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss)
+          Return the transducer losses and CTC loss, and potentially the codebook loss,
+          in form of (simple_loss, pruned_loss, ctc_loss, codebook_loss)
 
         Note:
            Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -323,7 +344,7 @@ class AsrModel(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
 
         # Compute encoder outputs
-        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
+        encoder_out, encoder_out_lens, middle_out = self.forward_encoder(x, x_lens)
 
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -355,4 +376,83 @@ class AsrModel(nn.Module):
         else:
             ctc_loss = torch.empty(0)
 
-        return simple_loss, pruned_loss, ctc_loss
+        if self.training and hasattr(self, "codebook_loss_net"):
+            assert codebook_indexes is not None
+            codebook_loss = self.forward_codebook(
+                middle_out=middle_out,
+                codebook_indexes=codebook_indexes,
+            )
+        else:
+            codebook_loss = torch.empty(0)
+
+        return simple_loss, pruned_loss, ctc_loss, codebook_loss
+
+    def forward_codebook(
+        self,
+        middle_out: List[torch.Tensor],
+        codebook_indexes: torch.Tensor,
+    ) -> torch.Tensor:
+        """Calculate the codebook loss for the model (knowledge distillation).
+        Args:
+          middle_out (List[torch.Tensor]):
+            The embeddings extracted from the middle layer of the zipformer encoder.
+          codebook_indexes (torch.Tensor):
+            The encoded codebook indexes for knowledge distillation.
+        Returns:
+          The codebook loss value.
+        """
+        # Currently only the output of one layer is supported, shape (N, T, C).
+        middle_layer_output = middle_out[0]
+        len_CI = codebook_indexes.size(1)
+        len_mid_layer = middle_layer_output.size(1)
+        ratio = round(len_CI / len_mid_layer)
+
+        if ratio == 1:  # The two have the same frame rate.
+            assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
+            codebook_indexes = codebook_indexes[:, :len_mid_layer, :]
+            assert codebook_indexes.size(1) == middle_layer_output.size(1)
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        elif ratio == 2:
+            codebook_indexes = self.concat_successive_codebook_indexes(
+                middle_layer_output, codebook_indexes
+            )
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+
+        return codebook_loss
+
+    @staticmethod
+    def concat_successive_codebook_indexes(
+        middle_layer_output, codebook_indexes
+    ):
+        # The output rate of HuBERT is 50 frames per second,
+        # while that of the current encoder is 25.
+        # The following code handles two issues:
+        # 1.
+        #    Roughly speaking, to generate one extra output frame,
+        #    HuBERT needs two extra input frames,
+        #    while the current encoder needs four extra input frames.
+        #    If only three extra frames are provided, HuBERT generates
+        #    one more frame while the current encoder does nothing.
+        # 2.
+        #    The codebook loss is a frame-wise loss. To let the student output
+        #    (25 frames per second) learn from the teacher output (50 frames per
+        #    second), two successive teacher frames are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handle issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        if T / t_expected < 1.1:  # To be changed: dirty hack to jump out of this function.
+            codebook_indexes = codebook_indexes[:, :t_expected, :]
+            assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+            return codebook_indexes
+        # Handle issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
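The frame-rate handling in concat_successive_codebook_indexes is easiest to see on a toy tensor: when the teacher (HuBERT, 50 frames per second) produces twice as many frames as the student encoder (25 frames per second), successive pairs of teacher frames are merged along the codebook dimension. A small stand-alone sketch with made-up shapes, using only torch:

import torch

# Toy illustration of the 50 Hz -> 25 Hz alignment performed above.
N, t_expected, C = 2, 6, 16   # batch, student frames, codebooks per teacher frame
teacher = torch.arange(N * 2 * t_expected * C).reshape(N, 2 * t_expected, C)

# Trim leftover teacher frames (issue 1), then merge successive pairs (issue 2).
teacher = teacher[:, : 2 * t_expected, :]
merged = teacher.reshape(N, t_expected, 2 * C)

# Frame i of merged holds the indexes of teacher frames 2*i and 2*i + 1, so the
# frame-wise codebook loss can be computed against the 25 Hz student output.
assert merged.shape == (N, t_expected, 2 * C)
assert torch.equal(merged[0, 0, :C], teacher[0, 0])  # first half  = teacher frame 0
assert torch.equal(merged[0, 0, C:], teacher[0, 1])  # second half = teacher frame 1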

File 2 of 3: the training script

@@ -68,7 +68,8 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -403,6 +404,34 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable MVQ knowledge distillation.",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of the codebook loss.",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks used for the extracted codebook indexes (CI).",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=4,
+        help="The 0-based encoder-stack index at which to perform MVQ knowledge distillation.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -579,6 +608,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder
@@ -630,6 +662,8 @@ def get_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        cb_input_dim=_to_int_tuple(params.encoder_dim)[params.distillation_layer],
     )
     return model
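cb_input_dim is taken from the encoder-dimension tuple at the distillation layer, so the codebook predictor is built with the width of the stack being distilled. A hedged sketch of that indexing with hypothetical values (the local _to_int_tuple below is a stand-in for the helper of the same name used in the training script, which parses a comma-separated string into a tuple of ints):

# Hypothetical values for illustration only; the real ones come from the
# --encoder-dim and --distillation-layer arguments of the training script.
encoder_dim = "192,256,384,512,384,256"
distillation_layer = 4

def _to_int_tuple(s: str) -> tuple:
    # Stand-in for the helper used in the training script.
    return tuple(int(x) for x in s.split(","))

cb_input_dim = _to_int_tuple(encoder_dim)[distillation_layer]
print(cb_input_dim)  # 384: the codebook predictor operates on a 384-dim embedding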
@@ -749,6 +783,16 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)
 
+
+def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
+    cuts = batch["supervisions"]["cut"]
+    # -100 is identical to ignore_value in CE loss computation.
+    cuts_pre_mixed = [
+        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+    ]
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens
+
+
 def compute_loss(
     params: AttributeDict,
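The pad_value=-100 above matches PyTorch's default ignore_index for cross-entropy, so padded frames should contribute nothing to the codebook loss (assuming the loss module ultimately relies on that convention, as the comment in extract_codebook_indexes suggests). A small self-contained check of the ignore behaviour with plain torch:

import torch
import torch.nn.functional as F

# Target positions set to -100 are skipped by cross_entropy by default
# (ignore_index=-100); this is the convention the -100 padding relies on.
logits = torch.randn(5, 10)                     # 5 frames, 10 classes
targets = torch.tensor([1, 3, 2, -100, -100])   # last two frames are padding

loss_valid_only = F.cross_entropy(logits[:3], targets[:3])
loss_with_padding = F.cross_entropy(logits, targets)
assert torch.allclose(loss_valid_only, loss_with_padding)  # padded frames are ignored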
@@ -791,14 +835,21 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        simple_loss, pruned_loss, ctc_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            codebook_indexes=codebook_indexes,
         )
 
         loss = 0.0
@@ -823,6 +874,9 @@ def compute_loss(
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
 
+        if is_training and params.enable_distillation:
+            loss += params.codebook_loss_scale * codebook_loss
+
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
@@ -837,6 +891,8 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
 
     return loss, info
@@ -1105,6 +1161,11 @@ def run(rank, world_size, args):
     else:
         tb_writer = None
 
+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    if params.enable_distillation:
+        assert (
+            args.spec_aug_time_warp_factor < 1
+        ), "Time warping in SpecAug should be disabled during MVQ distillation"
+
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
@@ -1234,14 +1295,14 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:

File 3 of 3: the Zipformer2 encoder

@@ -90,6 +90,8 @@ class Zipformer2(EncoderInterface):
           context chunks for causal training; will be rounded to a number of
           chunks.  Must not be less than cnn_module_kernel (after factoring in
           rounding and downsampling); an error will be thrown if this is violated.
+        middle_output_layer:
+          If not None, also return the output of this encoder stack (0-based index);
+          used for MVQ knowledge distillation.
    """
    def __init__(
        self,
@@ -110,6 +112,7 @@ class Zipformer2(EncoderInterface):
         causal: bool = False,
         chunk_size: Tuple[int] = [-1],
         left_context_frames: Tuple[int] = [-1],
+        middle_output_layer: Optional[int] = None,  # 0-based layer index
     ) -> None:
         super(Zipformer2, self).__init__()
@@ -191,6 +194,17 @@ class Zipformer2(EncoderInterface):
         self.encoders = nn.ModuleList(encoders)
 
+        # For MVQ: return the middle layer output.
+        output_layers = []
+        if middle_output_layer is not None:
+            assert (
+                middle_output_layer >= 0
+                and middle_output_layer < len(num_encoder_layers)
+            )
+            output_layers.append(middle_output_layer)
+        self.output_layers = output_layers  # A list of int
+
         self.downsample_output = SimpleDownsample(max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
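The output_layers bookkeeping follows a common pattern: while running the stack of encoders, keep the intermediate output whose index was requested and return it next to the final output. A minimal, self-contained sketch of that pattern (a toy stack of linear layers stands in for the Zipformer2 stacks, so it runs without icefall):

import torch
import torch.nn as nn
from typing import List, Optional, Tuple

class TinyStackedEncoder(nn.Module):
    """Toy stand-in for a stacked encoder that can expose one middle layer's output."""

    def __init__(self, dims: List[int], middle_output_layer: Optional[int] = None):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
        )
        self.output_layers = (
            [middle_output_layer] if middle_output_layer is not None else []
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        outputs = []
        for layer in self.layers:
            x = layer(x)
            outputs.append(x)
        # Keep only the requested middle-layer outputs, as Zipformer2 does with "saved".
        saved = [outputs[i] for i in self.output_layers]
        return x, saved

encoder = TinyStackedEncoder([8, 16, 32, 16], middle_output_layer=1)
final_out, saved = encoder(torch.randn(4, 8))
assert saved[0].shape == (4, 32)  # output of stack index 1 (0-based)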
@@ -334,6 +348,9 @@ class Zipformer2(EncoderInterface):
         x = self._get_full_dim_output(outputs)
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
+
+        # Collect the middle-layer embeddings, (T, N, C) -> (N, T, C).
+        saved = [outputs[i].permute(1, 0, 2) for i in self.output_layers]
+
         assert self.output_downsampling_factor == 2, self.output_downsampling_factor
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             lengths = (x_lens + 1) // 2
@@ -342,7 +359,7 @@ class Zipformer2(EncoderInterface):
                 warnings.simplefilter("ignore")
                 lengths = (x_lens + 1) // 2
-        return x, lengths
+        return x, lengths, saved
 
     def _get_attn_mask(
         self, x: Tensor,