add smooth-regularized CTC

yaozengwei 2024-10-10 16:47:57 +08:00
parent ae59e5d61e
commit a85592dc78
2 changed files with 99 additions and 68 deletions

View File

@@ -16,11 +16,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
@@ -111,9 +112,8 @@ class AsrModel(nn.Module):
         if use_ctc:
             # Modules for CTC head
             self.ctc_output = nn.Sequential(
-                nn.Dropout(p=0.1),
+                nn.Dropout(p=0.1),  # TODO: test removing this
                 nn.Linear(encoder_dim, vocab_size),
-                nn.LogSoftmax(dim=-1),
             )

         self.use_attention_decoder = use_attention_decoder
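
Note: with nn.LogSoftmax dropped from the head, self.ctc_output now emits raw logits, and normalization moves into forward_ctc (see the next hunk). A minimal standalone sketch, with hypothetical dimensions and dropout disabled for determinism, confirming that applying F.log_softmax to the logits reproduces the old head's output:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    encoder_dim, vocab_size = 8, 16  # hypothetical sizes, for illustration only
    head = nn.Sequential(nn.Dropout(p=0.1), nn.Linear(encoder_dim, vocab_size))
    head.eval()  # disable dropout so both paths are deterministic

    x = torch.randn(2, 5, encoder_dim)  # (N, T, encoder_dim)
    logits = head(x)                    # new head: raw logits
    old_style = nn.LogSoftmax(dim=-1)(logits)  # what the removed layer produced
    assert torch.allclose(F.log_softmax(logits, dim=-1), old_style)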
@@ -159,71 +159,82 @@ class AsrModel(nn.Module):
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute CTC loss.
-        Args:
-          encoder_out:
-            Encoder output, of shape (N, T, C).
-          encoder_out_lens:
-            Encoder output lengths, of shape (N,).
-          targets:
-            Target Tensor of shape (sum(target_lengths)). The targets are assumed
-            to be un-padded and concatenated within 1 dimension.
-        """
-        # Compute CTC log-prob
-        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
-
-        ctc_loss = torch.nn.functional.ctc_loss(
-            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
-            targets=targets.cpu(),
-            input_lengths=encoder_out_lens.cpu(),
-            target_lengths=target_lengths.cpu(),
-            reduction="sum",
-        )
-        return ctc_loss
-
-    def forward_cr_ctc(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        targets: torch.Tensor,
-        target_lengths: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Compute CTC loss with consistency regularization loss.
+        use_consistency_reg: bool = False,
+        use_smooth_reg: bool = False,
+        smooth_kernel: List[float] = [0.25, 0.5, 0.25],
+        eps: float = 1e-6,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute CTC loss, optionally with consistency regularization
+        and smooth regularization.
         Args:
           encoder_out:
-            Encoder output, of shape (2 * N, T, C).
+            Encoder output, of shape (N or 2 * N, T, C).
           encoder_out_lens:
-            Encoder output lengths, of shape (2 * N,).
+            Encoder output lengths, of shape (N or 2 * N,).
           targets:
-            Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
+            Target Tensor of shape (sum(target_lengths) or
+            2 * sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
+          use_consistency_reg:
+            Whether to use consistency regularization.
+          use_smooth_reg:
+            Whether to use smooth regularization.
         """
+        ctc_output = self.ctc_output(encoder_out)  # (N or 2 * N, T, C)
+        length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
+
+        if not use_smooth_reg:
+            ctc_log_probs = F.log_softmax(ctc_output, dim=-1)
+        else:
+            ctc_probs = ctc_output.softmax(dim=-1)  # Used in sr_loss
+            ctc_log_probs = (ctc_probs + eps).log()
+
         # Compute CTC loss
-        ctc_output = self.ctc_output(encoder_out)  # (2 * N, T, C)
         ctc_loss = torch.nn.functional.ctc_loss(
-            log_probs=ctc_output.permute(1, 0, 2),  # (T, 2 * N, C)
+            log_probs=ctc_log_probs.permute(1, 0, 2),  # (T, N or 2 * N, C)
             targets=targets.cpu(),
             input_lengths=encoder_out_lens.cpu(),
             target_lengths=target_lengths.cpu(),
             reduction="sum",
         )

-        # Compute consistency regularization loss
-        exchanged_targets = ctc_output.detach().chunk(2, dim=0)
-        exchanged_targets = torch.cat(
-            [exchanged_targets[1], exchanged_targets[0]], dim=0
-        )  # exchange: [x1, x2] -> [x2, x1]
-        cr_loss = nn.functional.kl_div(
-            input=ctc_output,
-            target=exchanged_targets,
-            reduction="none",
-            log_target=True,
-        )  # (2 * N, T, C)
-        length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
-        cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
-
-        return ctc_loss, cr_loss
+        if use_consistency_reg:
+            assert ctc_log_probs.shape[0] % 2 == 0
+            # Compute cr_loss
+            exchanged_targets = ctc_log_probs.detach().chunk(2, dim=0)
+            exchanged_targets = torch.cat(
+                [exchanged_targets[1], exchanged_targets[0]], dim=0
+            )  # exchange: [x1, x2] -> [x2, x1]
+            cr_loss = nn.functional.kl_div(
+                input=ctc_log_probs,
+                target=exchanged_targets,
+                reduction="none",
+                log_target=True,
+            )  # (2 * N, T, C)
+            cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
+        else:
+            cr_loss = torch.empty(0)
+
+        if use_smooth_reg:
+            # Hard-coded 3-tap kernel by default; other values could be tried
+            assert len(smooth_kernel) == 3 and sum(smooth_kernel) == 1.0, smooth_kernel
+            smooth_kernel = torch.tensor(
+                smooth_kernel,
+                dtype=ctc_probs.dtype,
+                device=ctc_probs.device,
+                requires_grad=False,
+            )
+            smooth_kernel = smooth_kernel.unsqueeze(0).unsqueeze(1).expand(
+                ctc_probs.shape[-1], 1, 3
+            )  # Now kernel: (C, 1, 3)
+            smoothed_ctc_probs = F.conv1d(
+                ctc_probs.detach().permute(0, 2, 1),  # (N or 2 * N, C, T)
+                weight=smooth_kernel,
+                stride=1,
+                padding=0,
+                groups=ctc_probs.shape[-1],
+            ).permute(0, 2, 1)  # (N or 2 * N, T - 2, C)
+            sr_loss = nn.functional.kl_div(
+                input=ctc_log_probs[:, 1:-1],
+                target=(smoothed_ctc_probs + eps).log(),
+                reduction="none",
+                log_target=True,
+            )  # (N or 2 * N, T - 2, C)
+            sr_loss = sr_loss.masked_fill(length_mask[:, 1:-1], 0.0).sum()
+        else:
+            sr_loss = torch.empty(0)
+
+        return ctc_loss, cr_loss, sr_loss

     def forward_transducer(
         self,
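
The sr_loss above pulls each frame's posterior toward a detached local average of its neighbors. A standalone sketch of the depthwise-conv smoothing, with made-up shapes, showing why the time axis shrinks by two and what the kernel computes:

    import torch
    import torch.nn.functional as F

    N, T, C = 2, 10, 5  # hypothetical batch size, frame count, vocab size
    probs = torch.rand(N, T, C).softmax(dim=-1)  # stand-in for ctc_probs

    kernel = torch.tensor([0.25, 0.5, 0.25])
    kernel = kernel.unsqueeze(0).unsqueeze(1).expand(C, 1, 3)  # (C, 1, 3)

    smoothed = F.conv1d(
        probs.permute(0, 2, 1),  # conv1d expects (N, C, T)
        weight=kernel,
        groups=C,  # depthwise: each vocab entry is smoothed independently over time
    ).permute(0, 2, 1)

    assert smoothed.shape == (N, T - 2, C)  # 3-tap kernel, no padding: 2 frames lost
    # Frame t of `smoothed` mixes frames t, t+1, t+2 of `probs`:
    expected = 0.25 * probs[:, :-2] + 0.5 * probs[:, 1:-1] + 0.25 * probs[:, 2:]
    assert torch.allclose(smoothed, expected, atol=1e-6)

This is also why the KL term is computed against ctc_log_probs[:, 1:-1]: the smoothed target only exists for the interior frames.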
@@ -341,6 +352,7 @@ class AsrModel(nn.Module):
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
         use_cr_ctc: bool = False,
+        use_sr_ctc: bool = False,
         use_spec_aug: bool = False,
         spec_augment: Optional[SpecAugment] = None,
         supervision_segments: Optional[torch.Tensor] = None,
@@ -367,6 +379,8 @@ class AsrModel(nn.Module):
            part
           use_cr_ctc:
             Whether to use consistency-regularized CTC.
+          use_sr_ctc:
+            Whether to use smooth-regularized CTC.
           use_spec_aug:
             Whether to apply spec-augment manually; used only if use_cr_ctc is True.
           spec_augment:
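
For context, the consistency-regularization branch assumes the batch was duplicated into two differently spec-augmented views of the same utterances, and pulls each view's log-probs toward the detached log-probs of the other. A minimal sketch with hypothetical shapes (padding masks omitted; the real code zeroes padded frames before summing):

    import torch
    import torch.nn as nn

    two_n, t, c = 4, 6, 5  # halves of the batch are paired views of the same audio
    log_probs = torch.randn(two_n, t, c).log_softmax(dim=-1)

    # Swap the halves so each view is scored against the other's frozen prediction.
    first, second = log_probs.detach().chunk(2, dim=0)
    exchanged = torch.cat([second, first], dim=0)

    cr_loss = nn.functional.kl_div(
        input=log_probs, target=exchanged, reduction="none", log_target=True
    ).sum()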
@@ -445,26 +459,23 @@ class AsrModel(nn.Module):
         if self.use_ctc:
             # Compute CTC loss
             targets = y.values
-            if not use_cr_ctc:
-                ctc_loss = self.forward_ctc(
-                    encoder_out=encoder_out,
-                    encoder_out_lens=encoder_out_lens,
-                    targets=targets,
-                    target_lengths=y_lens,
-                )
-                cr_loss = torch.empty(0)
-            else:
-                ctc_loss, cr_loss = self.forward_cr_ctc(
-                    encoder_out=encoder_out,
-                    encoder_out_lens=encoder_out_lens,
-                    targets=targets,
-                    target_lengths=y_lens,
-                )
-                ctc_loss = ctc_loss * 0.5
-                cr_loss = cr_loss * 0.5
+            ctc_loss, cr_loss, sr_loss = self.forward_ctc(
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                targets=targets,
+                target_lengths=y_lens,
+                use_consistency_reg=use_cr_ctc,
+                use_smooth_reg=use_sr_ctc,
+            )
+            if use_cr_ctc:
+                # We duplicate the batch when use_cr_ctc is True
+                ctc_loss = ctc_loss * 0.5
+                cr_loss = cr_loss * 0.5
+                sr_loss = sr_loss * 0.5
         else:
             ctc_loss = torch.empty(0)
             cr_loss = torch.empty(0)
+            sr_loss = torch.empty(0)

         if self.use_attention_decoder:
             attention_decoder_loss = self.attention_decoder.calc_att_loss(
@@ -478,4 +489,4 @@ class AsrModel(nn.Module):
         else:
             attention_decoder_loss = torch.empty(0)

-        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
+        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss

View File

@@ -312,6 +312,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use consistency-regularized CTC.",
     )

+    parser.add_argument(
+        "--use-sr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use smooth-regularized CTC.",
+    )
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -464,6 +471,13 @@ def get_parser():
         help="Scale for consistency-regularization loss.",
     )

+    parser.add_argument(
+        "--sr-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for smooth-regularization loss.",
+    )
+
     parser.add_argument(
         "--time-mask-ratio",
         type=float,
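
A quick standalone check of the two new options (sketch only; the plain function below stands in for icefall's str2bool helper, which is not imported here):

    import argparse

    def str2bool(s):  # stand-in for icefall's str2bool
        return str(s).lower() in ("1", "true", "yes")

    parser = argparse.ArgumentParser()
    parser.add_argument("--use-sr-ctc", type=str2bool, default=False,
                        help="If True, use smooth-regularized CTC.")
    parser.add_argument("--sr-loss-scale", type=float, default=0.2,
                        help="Scale for smooth-regularization loss.")

    args = parser.parse_args(["--use-sr-ctc", "1", "--sr-loss-scale", "0.1"])
    assert args.use_sr_ctc is True and args.sr_loss_scale == 0.1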
@@ -916,6 +930,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     use_cr_ctc = params.use_cr_ctc
+    use_sr_ctc = params.use_sr_ctc
     use_spec_aug = use_cr_ctc and is_training
     if use_spec_aug:
         supervision_intervals = batch["supervisions"]
@@ -931,7 +946,7 @@ def compute_loss(
         supervision_segments = None

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
+        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -939,6 +954,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             use_cr_ctc=use_cr_ctc,
+            use_sr_ctc=use_sr_ctc,
             use_spec_aug=use_spec_aug,
             spec_augment=spec_augment,
             supervision_segments=supervision_segments,
@@ -967,6 +983,8 @@ def compute_loss(
         loss += params.ctc_loss_scale * ctc_loss
         if use_cr_ctc:
             loss += params.cr_loss_scale * cr_loss
+        if use_sr_ctc:
+            loss += params.sr_loss_scale * sr_loss

     if params.use_attention_decoder:
         loss += params.attention_decoder_loss_scale * attention_decoder_loss
@@ -985,8 +1003,10 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
-        if params.use_cr_ctc:
+        if use_cr_ctc:
             info["cr_loss"] = cr_loss.detach().cpu().item()
+        if use_sr_ctc:
+            info["sr_loss"] = sr_loss.detach().cpu().item()
     if params.use_attention_decoder:
         info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()