Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-10 17:44:20 +00:00)

Commit 3705a58624: support MVQ as a training option
Parent: 19b942c958
@@ -16,12 +16,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from multi_quantization.prediction import JointCodebookLoss
 
 from icefall.utils import add_sos, make_pad_mask
 from scaling import ScaledLinear
@@ -39,12 +40,15 @@ class AsrModel(nn.Module):
         vocab_size: int = 500,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        num_codebooks: int = 8,
+        cb_input_dim: int = 384,
     ):
         """A joint CTC & Transducer ASR model.
 
         - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
         - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
         - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
+        - Potentially with MVQ knowledge distillation (https://arxiv.org/abs/2211.00508)
 
         Args:
           encoder_embed:
@@ -70,6 +74,10 @@ class AsrModel(nn.Module):
             Whether use transducer head. Default: True.
           use_ctc:
             Whether use CTC head. Default: False.
+          num_codebooks:
+            Greater than 0 if we want to do MVQ knowledge distillation.
+          cb_input_dim:
+            The input dimension to the codebook loss module.
         """
         super().__init__()
 
@@ -110,6 +118,12 @@ class AsrModel(nn.Module):
                 nn.Linear(encoder_dim, vocab_size),
                 nn.LogSoftmax(dim=-1),
             )
 
+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=cb_input_dim,
+                num_codebooks=num_codebooks,
+            )
+
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
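Aside (not part of the commit): a minimal sketch of how the new codebook loss module is driven, assuming the multi_quantization package is installed and that JointCodebookLoss accepts exactly what the call sites in this diff pass it: float embeddings of shape (N, T, C) and integer codebook indexes of shape (N, T, num_codebooks). The 256-entry codebook size and all tensor values below are illustrative.

```python
import torch
from multi_quantization.prediction import JointCodebookLoss  # assumed installed

N, T, C, num_codebooks = 2, 100, 384, 16

# Same constructor arguments as in AsrModel.__init__ above.
codebook_loss_net = JointCodebookLoss(
    predictor_channels=C,
    num_codebooks=num_codebooks,
)

middle_layer_output = torch.randn(N, T, C)  # student embeddings from a middle layer
codebook_indexes = torch.randint(0, 256, (N, T, num_codebooks))  # teacher MVQ targets (256 classes assumed)

# Same call pattern as AsrModel.forward_codebook later in this diff.
codebook_loss = codebook_loss_net(middle_layer_output, codebook_indexes)
```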
@@ -127,6 +141,8 @@ class AsrModel(nn.Module):
             Encoder output, of shape (N, T, C).
           encoder_out_lens:
             Encoder output lengths, of shape (N,).
+          saved_embeddings:
+            The embeddings from the middle layers
         """
         # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
         x, x_lens = self.encoder_embed(x, x_lens)
@@ -135,12 +151,12 @@ class AsrModel(nn.Module):
         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask)
 
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
 
-        return encoder_out, encoder_out_lens
+        return encoder_out, encoder_out_lens, middle_out
 
     def forward_ctc(
         self,
@@ -180,6 +196,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -286,6 +303,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -306,9 +324,12 @@ class AsrModel(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          codebook_indexes:
+            The codebook indexes to be predicted. Only used when doing knowledge
+            distillation with MVQ
         Returns:
-          Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss)
+          Return the transducer losses and CTC loss, and potentially codebook loss
+          in form of (simple_loss, pruned_loss, ctc_loss, codebook_loss)
 
         Note:
            Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -323,7 +344,7 @@ class AsrModel(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
 
         # Compute encoder outputs
-        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
+        encoder_out, encoder_out_lens, middle_out = self.forward_encoder(x, x_lens)
 
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -354,5 +375,84 @@ class AsrModel(nn.Module):
             )
         else:
             ctc_loss = torch.empty(0)
 
+        if self.training and hasattr(self, "codebook_loss_net"):
+            assert codebook_indexes is not None
+            codebook_loss = self.forward_codebook(
+                middle_out=middle_out,
+                codebook_indexes=codebook_indexes,
+            )
+        else:
+            codebook_loss = torch.empty(0)
+
-        return simple_loss, pruned_loss, ctc_loss
+        return simple_loss, pruned_loss, ctc_loss, codebook_loss
+
+    def forward_codebook(
+        self,
+        middle_out: List[torch.Tensor],
+        codebook_indexes: torch.Tensor,
+    ) -> torch.Tensor:
+        """Calculate the codebook loss for the model (knowledge distillation)
+
+        Args:
+            middle_out (List[torch.Tensor]):
+                The embeddings extracted from the middle layer of the zipformer encoder
+            codebook_indexes (torch.Tensor):
+                The encoded codebook indexes for knowledge distillation
+
+        Returns:
+            The codebook loss value
+        """
+        middle_layer_output = middle_out[0]  # currently only supports using the output of one layer, (N, T, C)
+        len_CI = codebook_indexes.size(1)
+        len_mid_layer = middle_layer_output.size(1)
+        ratio = round(len_CI / len_mid_layer)
+
+        if ratio == 1:  # Having the same frame rate
+            assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
+            codebook_indexes = codebook_indexes[:, :len_mid_layer, :]
+            assert codebook_indexes.size(1) == middle_layer_output.size(1)
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        elif ratio == 2:
+            codebook_indexes = self.concat_successive_codebook_indexes(
+                middle_layer_output, codebook_indexes
+            )
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+
+        return codebook_loss
+
+    @staticmethod
+    def concat_successive_codebook_indexes(
+        middle_layer_output, codebook_indexes
+    ):
+        # Output rate of hubert is 50 frames per second,
+        # while that of the current encoder is 25.
+        # The following code handles two issues:
+        # 1.
+        #   Roughly speaking, to generate one extra frame of output,
+        #   hubert needs two extra input frames,
+        #   while the current encoder needs four extra frames.
+        #   If only three extra frames are provided, hubert will generate
+        #   another frame while the current encoder does nothing.
+        # 2.
+        #   The codebook loss is a frame-wise loss. To let the 25-frame student
+        #   output learn from the 50-frame teacher output, two successive frames
+        #   of the teacher output are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handling issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        if T / t_expected < 1.1:  # To be changed, dirty hack to jump out of this function
+            codebook_indexes = codebook_indexes[:, : t_expected, :]
+            assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+            return codebook_indexes
+        # Handling issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
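Aside (not part of the commit): the reshape in concat_successive_codebook_indexes pairs every two successive teacher frames, so one 25 Hz student frame predicts 2 * C codebook indexes. A small shape check with illustrative values:

```python
import torch

N, C = 2, 16          # batch size, codebooks per teacher frame
t_student = 48        # student (middle layer) frames
T_teacher = 100       # 50 Hz teacher frames

codebook_indexes = torch.randint(0, 256, (N, T_teacher, C))

# Issue 1: keep at most 2 * t_student teacher frames.
codebook_indexes = codebook_indexes[:, : t_student * 2, :]

# Issue 2: merge each pair of successive teacher frames into one target row.
codebook_indexes = codebook_indexes.reshape(N, t_student, C * 2)
print(codebook_indexes.shape)  # torch.Size([2, 48, 32])
```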
@@ -68,7 +68,8 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -402,6 +403,34 @@ def get_parser():
         default=0.2,
         help="Scale for CTC loss.",
     )
 
+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of codebook loss.",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks used for the extracted CI (codebook indexes)",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=4,
+        help="Where to perform MVQ-KD",
+    )
+
     parser.add_argument(
         "--seed",
@@ -579,6 +608,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder
 
@@ -630,6 +662,8 @@ def get_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        cb_input_dim=_to_int_tuple(params.encoder_dim)[params.distillation_layer],
     )
     return model
 
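Aside (not part of the commit): cb_input_dim is simply the encoder dimension of the stack selected by --distillation-layer. A worked example with a hypothetical --encoder-dim value; the helper below is an illustrative stand-in for the _to_int_tuple used in this script.

```python
def _to_int_tuple(s: str):
    # illustrative stand-in for the helper used in train.py
    return tuple(map(int, s.split(",")))

encoder_dim = "192,256,384,512,384,256"   # hypothetical --encoder-dim value
distillation_layer = 4                    # default of --distillation-layer above

cb_input_dim = _to_int_tuple(encoder_dim)[distillation_layer]
assert cb_input_dim == 384                # matches AsrModel's default cb_input_dim
```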
@@ -749,6 +783,16 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)
 
+def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
+    cuts = batch["supervisions"]["cut"]
+    # -100 is identical to ignore_value in CE loss computation.
+    cuts_pre_mixed = [
+        c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+    ]
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens
 
 def compute_loss(
     params: AttributeDict,
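Aside (not part of the commit): pad_value=-100 matches PyTorch's default ignore_index for cross-entropy, which is presumably what the "-100 is identical to ignore_value" comment refers to; padded frames then contribute nothing to a CE-style codebook loss. A self-contained illustration:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, 256)                      # 5 frames, 256 classes per codebook (illustrative)
targets = torch.tensor([3, 17, 250, -100, -100])  # last two frames are padding

loss = F.cross_entropy(logits, targets)           # ignore_index defaults to -100
print(loss.item())
```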
@@ -790,15 +834,22 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        simple_loss, pruned_loss, ctc_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            codebook_indexes=codebook_indexes,
         )
 
         loss = 0.0
@@ -822,6 +873,9 @@ def compute_loss(
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
 
+        if is_training and params.enable_distillation:
+            loss += params.codebook_loss_scale * codebook_loss
+
     assert loss.requires_grad == is_training
 
@@ -837,6 +891,8 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
 
     return loss, info
 
@@ -1105,6 +1161,11 @@ def run(rank, world_size, args):
     else:
         tb_writer = None
 
+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    if params.enable_distillation:
+        assert args.spec_aug_time_warp_factor < 1, "Specaug time warp should be disabled during distillation"
+
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
@@ -1234,14 +1295,14 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -90,6 +90,8 @@ class Zipformer2(EncoderInterface):
            context chunks for causal training; will be rounded to a number of
            chunks. Must not be less than cnn_module_kernel (after factoring in
            rounding and downsampling); an error will be thrown if this is violated.
+        middle_output_layer:
+           If not None, also return the output of this middle layer of the model (0-based index).
    """
    def __init__(
        self,
@@ -110,6 +112,7 @@ class Zipformer2(EncoderInterface):
         causal: bool = False,
         chunk_size: Tuple[int] = [-1],
         left_context_frames: Tuple[int] = [-1],
+        middle_output_layer: int = None  # 0-based layer index
     ) -> None:
         super(Zipformer2, self).__init__()
 
@@ -190,6 +193,17 @@ class Zipformer2(EncoderInterface):
             encoders.append(encoder)
 
         self.encoders = nn.ModuleList(encoders)
 
+        # for mvq: return the middle layer output
+        output_layers = []
+        if middle_output_layer is not None:
+            assert (
+                middle_output_layer >= 0
+                and middle_output_layer < len(num_encoder_layers)
+            )
+            output_layers.append(middle_output_layer)
+
+        self.output_layers = output_layers  # A list of int
+
         self.downsample_output = SimpleDownsample(max(encoder_dim),
                                                   downsample=output_downsampling_factor,
@@ -334,6 +348,9 @@ class Zipformer2(EncoderInterface):
         x = self._get_full_dim_output(outputs)
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
+
+        saved = [outputs[i].permute(1, 0, 2) for i in self.output_layers]  # collect the embeddings
+
         assert self.output_downsampling_factor == 2, self.output_downsampling_factor
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             lengths = (x_lens + 1) // 2
@@ -342,7 +359,7 @@ class Zipformer2(EncoderInterface):
             warnings.simplefilter("ignore")
             lengths = (x_lens + 1) // 2
 
-        return x, lengths
+        return x, lengths, saved
 
     def _get_attn_mask(
         self, x: Tensor,
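Aside (not part of the commit): the saved middle-layer embeddings keep the frame rate at which the encoder stacks run, while the returned x has been downsampled by output_downsampling_factor == 2; AsrModel.forward_codebook aligns the teacher's codebook indexes against the saved embeddings, not against the downsampled output. A rough length check with illustrative values:

```python
import torch

x_lens = torch.tensor([204, 150])   # frame counts entering the encoder stacks
out_lens = (x_lens + 1) // 2        # lengths returned for the downsampled output x
mid_lens = x_lens                   # lengths of the saved middle-layer embeddings
print(out_lens.tolist(), mid_lens.tolist())  # [102, 75] [204, 150]
```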