mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add smooth-regularized CTC
This commit is contained in:
parent
ae59e5d61e
commit
a85592dc78
@ -16,11 +16,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
@ -111,9 +112,8 @@ class AsrModel(nn.Module):
|
||||
if use_ctc:
|
||||
# Modules for CTC head
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Dropout(p=0.1), # TODO: test removing this
|
||||
nn.Linear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
self.use_attention_decoder = use_attention_decoder
|
||||
@ -159,71 +159,82 @@ class AsrModel(nn.Module):
|
||||
encoder_out_lens: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute CTC loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
targets:
|
||||
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||
to be un-padded and concatenated within 1 dimension.
|
||||
"""
|
||||
# Compute CTC log-prob
|
||||
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||
targets=targets.cpu(),
|
||||
input_lengths=encoder_out_lens.cpu(),
|
||||
target_lengths=target_lengths.cpu(),
|
||||
reduction="sum",
|
||||
)
|
||||
return ctc_loss
|
||||
|
||||
def forward_cr_ctc(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
use_consistency_reg: bool = False,
|
||||
use_smooth_reg: bool = False,
|
||||
smooth_kernel: List[float] = [0.25, 0.5, 0.25],
|
||||
eps: float = 1e-6,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute CTC loss with consistency regularization loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (2 * N, T, C).
|
||||
Encoder output, of shape (N or 2 * N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (2 * N,).
|
||||
Encoder output lengths, of shape (N or 2 * N,).
|
||||
targets:
|
||||
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
|
||||
to be un-padded and concatenated within 1 dimension.
|
||||
use_consistency_reg:
|
||||
Whether use consistency regularization.
|
||||
use_smooth_reg:
|
||||
Whether use smooth regularization.
|
||||
"""
|
||||
ctc_output = self.ctc_output(encoder_out) # (N or 2 * N, T, C)
|
||||
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
|
||||
|
||||
if not use_smooth_reg:
|
||||
ctc_log_probs = F.log_softmax(ctc_output, dim=-1)
|
||||
else:
|
||||
ctc_probs = ctc_output.softmax(dim=-1) # Used in sr_loss
|
||||
ctc_log_probs = (ctc_probs + eps).log()
|
||||
|
||||
# Compute CTC loss
|
||||
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
|
||||
log_probs=ctc_log_probs.permute(1, 0, 2), # (T, N or 2 * N, C)
|
||||
targets=targets.cpu(),
|
||||
input_lengths=encoder_out_lens.cpu(),
|
||||
target_lengths=target_lengths.cpu(),
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
# Compute consistency regularization loss
|
||||
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
|
||||
exchanged_targets = torch.cat(
|
||||
[exchanged_targets[1], exchanged_targets[0]], dim=0
|
||||
) # exchange: [x1, x2] -> [x2, x1]
|
||||
cr_loss = nn.functional.kl_div(
|
||||
input=ctc_output,
|
||||
target=exchanged_targets,
|
||||
reduction="none",
|
||||
log_target=True,
|
||||
) # (2 * N, T, C)
|
||||
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
|
||||
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
|
||||
if use_consistency_reg:
|
||||
assert ctc_log_probs.shape[0] % 2 == 0
|
||||
# Compute cr_loss
|
||||
exchanged_targets = ctc_log_probs.detach().chunk(2, dim=0)
|
||||
exchanged_targets = torch.cat(
|
||||
[exchanged_targets[1], exchanged_targets[0]], dim=0
|
||||
) # exchange: [x1, x2] -> [x2, x1]
|
||||
cr_loss = nn.functional.kl_div(
|
||||
input=ctc_log_probs,
|
||||
target=exchanged_targets,
|
||||
reduction="none",
|
||||
log_target=True,
|
||||
) # (2 * N, T, C)
|
||||
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
|
||||
else:
|
||||
cr_loss = torch.empty(0)
|
||||
|
||||
return ctc_loss, cr_loss
|
||||
if use_smooth_reg:
|
||||
# Hard code the kernel here, could try other values
|
||||
assert len(smooth_kernel) == 3 and sum(smooth_kernel) == 1.0, smooth_kernel
|
||||
smooth_kernel = torch.tensor(smooth_kernel, dtype=ctc_probs.dtype,
|
||||
device=ctc_probs.device, requires_grad=False)
|
||||
smooth_kernel = smooth_kernel.unsqueeze(0).unsqueeze(1).expand(ctc_probs.shape[-1], 1, 3)
|
||||
# Now kernel: (C, 1, 3)
|
||||
smoothed_ctc_probs = F.conv1d(
|
||||
ctc_probs.detach().permute(0, 2, 1), # (N or 2 * N, C, T)
|
||||
weight=smooth_kernel, stride=1, padding=0, groups=ctc_probs.shape[-1]
|
||||
).permute(0, 2, 1) # (N or 2 * N, T - 2, C)
|
||||
sr_loss = nn.functional.kl_div(
|
||||
input=ctc_log_probs[:, 1:-1],
|
||||
target=(smoothed_ctc_probs + eps).log(),
|
||||
reduction="none",
|
||||
log_target=True,
|
||||
) # (N, T - 1 , C)
|
||||
sr_loss = sr_loss.masked_fill(length_mask[:, 1:-1], 0.0).sum()
|
||||
else:
|
||||
sr_loss = torch.empty(0)
|
||||
|
||||
return ctc_loss, cr_loss, sr_loss
|
||||
|
||||
def forward_transducer(
|
||||
self,
|
||||
@ -341,6 +352,7 @@ class AsrModel(nn.Module):
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
use_cr_ctc: bool = False,
|
||||
use_sr_ctc: bool = False,
|
||||
use_spec_aug: bool = False,
|
||||
spec_augment: Optional[SpecAugment] = None,
|
||||
supervision_segments: Optional[torch.Tensor] = None,
|
||||
@ -367,6 +379,8 @@ class AsrModel(nn.Module):
|
||||
part
|
||||
use_cr_ctc:
|
||||
Whether use consistency-regularized CTC.
|
||||
use_sr_ctc:
|
||||
Whether use smooth-regularized CTC.
|
||||
use_spec_aug:
|
||||
Whether apply spec-augment manually, used only if use_cr_ctc is True.
|
||||
spec_augment:
|
||||
@ -445,26 +459,23 @@ class AsrModel(nn.Module):
|
||||
if self.use_ctc:
|
||||
# Compute CTC loss
|
||||
targets = y.values
|
||||
if not use_cr_ctc:
|
||||
ctc_loss = self.forward_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
)
|
||||
cr_loss = torch.empty(0)
|
||||
else:
|
||||
ctc_loss, cr_loss = self.forward_cr_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
)
|
||||
ctc_loss, cr_loss, sr_loss = self.forward_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
use_consistency_reg=use_cr_ctc,
|
||||
use_smooth_reg=use_sr_ctc,
|
||||
)
|
||||
if use_cr_ctc:
|
||||
# We duplicate the batch when use_cr_ctc is True
|
||||
ctc_loss = ctc_loss * 0.5
|
||||
cr_loss = cr_loss * 0.5
|
||||
sr_loss = sr_loss * 0.5
|
||||
else:
|
||||
ctc_loss = torch.empty(0)
|
||||
cr_loss = torch.empty(0)
|
||||
sr_loss = torch.empty(0)
|
||||
|
||||
if self.use_attention_decoder:
|
||||
attention_decoder_loss = self.attention_decoder.calc_att_loss(
|
||||
@ -478,4 +489,4 @@ class AsrModel(nn.Module):
|
||||
else:
|
||||
attention_decoder_loss = torch.empty(0)
|
||||
|
||||
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
|
||||
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss
|
||||
|
@ -312,6 +312,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="If True, use consistency-regularized CTC.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-sr-ctc",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If True, use smooth-regularized CTC.",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -464,6 +471,13 @@ def get_parser():
|
||||
help="Scale for consistency-regularization loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sr-loss-scale",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Scale for smooth-regularization loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--time-mask-ratio",
|
||||
type=float,
|
||||
@ -916,6 +930,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
use_cr_ctc = params.use_cr_ctc
|
||||
use_sr_ctc = params.use_sr_ctc
|
||||
use_spec_aug = use_cr_ctc and is_training
|
||||
if use_spec_aug:
|
||||
supervision_intervals = batch["supervisions"]
|
||||
@ -931,7 +946,7 @@ def compute_loss(
|
||||
supervision_segments = None
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
|
||||
simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -939,6 +954,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
use_cr_ctc=use_cr_ctc,
|
||||
use_sr_ctc=use_sr_ctc,
|
||||
use_spec_aug=use_spec_aug,
|
||||
spec_augment=spec_augment,
|
||||
supervision_segments=supervision_segments,
|
||||
@ -967,6 +983,8 @@ def compute_loss(
|
||||
loss += params.ctc_loss_scale * ctc_loss
|
||||
if use_cr_ctc:
|
||||
loss += params.cr_loss_scale * cr_loss
|
||||
if use_sr_ctc:
|
||||
loss += params.sr_loss_scale * sr_loss
|
||||
|
||||
if params.use_attention_decoder:
|
||||
loss += params.attention_decoder_loss_scale * attention_decoder_loss
|
||||
@ -985,8 +1003,10 @@ def compute_loss(
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
if params.use_ctc:
|
||||
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||
if params.use_cr_ctc:
|
||||
if use_cr_ctc:
|
||||
info["cr_loss"] = cr_loss.detach().cpu().item()
|
||||
if use_sr_ctc:
|
||||
info["sr_loss"] = sr_loss.detach().cpu().item()
|
||||
if params.use_attention_decoder:
|
||||
info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user