Merge a85592dc784de1ac77ca4dac861b46b5e13b539d into f84270c93528f4b77b99ada9ac0c9f7fb231d6a4

Zengwei Yao 2024-10-17 17:01:20 +08:00 committed by GitHub
commit 0be47510a9
3 changed files with 273 additions and 21 deletions
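In brief: the diff adds two optional CTC regularizers to the recipe's AsrModel. Consistency-regularized CTC (CR-CTC) duplicates each batch, applies SpecAugment independently to the two copies, and penalizes the KL divergence between their CTC posteriors; smooth-regularized CTC (SR-CTC) penalizes divergence from a time-smoothed copy of the posteriors. The model file implements the losses, the training script wires up the new command-line flags and loss scales, and icefall/utils.py gains a time_warp helper applied before the batch is duplicated.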

model.py

@@ -16,15 +16,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask
from icefall.utils import add_sos, make_pad_mask, time_warp
from lhotse.dataset import SpecAugment
class AsrModel(nn.Module):
@@ -110,9 +112,8 @@ class AsrModel(nn.Module):
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Dropout(p=0.1), # TODO: test removing this
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
self.use_attention_decoder = use_attention_decoder
@@ -158,28 +159,82 @@ class AsrModel(nn.Module):
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
use_consistency_reg: bool = False,
use_smooth_reg: bool = False,
smooth_kernel: List[float] = [0.25, 0.5, 0.25],
eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute CTC loss with consistency regularization loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
Encoder output, of shape (N or 2 * N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
Encoder output lengths, of shape (N or 2 * N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
Target Tensor of shape (sum(target_lengths) or 2 * sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
use_consistency_reg:
Whether to use consistency regularization.
use_smooth_reg:
Whether to use smooth regularization.
smooth_kernel:
The kernel used to smooth the CTC probabilities along the time axis;
used only if use_smooth_reg is True.
eps:
A small constant added before taking the log, to avoid log(0).
Returns:
Return a tuple of (ctc_loss, cr_loss, sr_loss); cr_loss and sr_loss are
empty tensors if the corresponding flag is False.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_output = self.ctc_output(encoder_out) # (N or 2 * N, T, C)
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
if not use_smooth_reg:
ctc_log_probs = F.log_softmax(ctc_output, dim=-1)
else:
ctc_probs = ctc_output.softmax(dim=-1) # Used in sr_loss
ctc_log_probs = (ctc_probs + eps).log()
# Compute CTC loss
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
log_probs=ctc_log_probs.permute(1, 0, 2), # (T, N or 2 * N, C)
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss
if use_consistency_reg:
assert ctc_log_probs.shape[0] % 2 == 0
# Compute cr_loss
exchanged_targets = ctc_log_probs.detach().chunk(2, dim=0)
exchanged_targets = torch.cat(
[exchanged_targets[1], exchanged_targets[0]], dim=0
) # exchange: [x1, x2] -> [x2, x1]
cr_loss = F.kl_div(
input=ctc_log_probs,
target=exchanged_targets,
reduction="none",
log_target=True,
) # (2 * N, T, C)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
else:
cr_loss = torch.empty(0)
if use_smooth_reg:
# Validate the smoothing kernel; values other than [0.25, 0.5, 0.25] could be tried
assert len(smooth_kernel) == 3 and sum(smooth_kernel) == 1.0, smooth_kernel
smooth_kernel = torch.tensor(smooth_kernel, dtype=ctc_probs.dtype,
device=ctc_probs.device, requires_grad=False)
smooth_kernel = smooth_kernel.unsqueeze(0).unsqueeze(1).expand(ctc_probs.shape[-1], 1, 3)
# Now kernel: (C, 1, 3)
smoothed_ctc_probs = F.conv1d(
ctc_probs.detach().permute(0, 2, 1), # (N or 2 * N, C, T)
weight=smooth_kernel, stride=1, padding=0, groups=ctc_probs.shape[-1]
).permute(0, 2, 1) # (N or 2 * N, T - 2, C)
sr_loss = F.kl_div(
input=ctc_log_probs[:, 1:-1],
target=(smoothed_ctc_probs + eps).log(),
reduction="none",
log_target=True,
) # (N or 2 * N, T - 2, C)
sr_loss = sr_loss.masked_fill(length_mask[:, 1:-1], 0.0).sum()
else:
sr_loss = torch.empty(0)
return ctc_loss, cr_loss, sr_loss
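Pulled out of the diff context, the two regularizers reduce to a few tensor ops. Below is a minimal, self-contained sketch on random data (tensor names and sizes are illustrative, not from the commit; padding masks omitted for brevity): the consistency term swaps the two augmented copies and penalizes their KL divergence, while the smoothness term compares each frame against a [0.25, 0.5, 0.25] moving average of its neighbors.

# Minimal sketch of the CR and SR terms on random data (illustrative only).
import torch
import torch.nn.functional as F

N, T, C = 2, 10, 5  # N utterances, duplicated to 2 * N for CR-CTC
ctc_log_probs = F.log_softmax(torch.randn(2 * N, T, C), dim=-1)

# Consistency regularization: pull each copy toward the other's
# (detached) frame-level distribution via KL divergence.
first, second = ctc_log_probs.detach().chunk(2, dim=0)
exchanged_targets = torch.cat([second, first], dim=0)  # [x1, x2] -> [x2, x1]
cr_loss = F.kl_div(
    input=ctc_log_probs, target=exchanged_targets,
    reduction="none", log_target=True,
).sum()

# Smooth regularization: the target is a depthwise 3-tap moving average
# over time, so the smoothed output loses one frame at each end.
ctc_probs = ctc_log_probs.exp()
kernel = torch.tensor([0.25, 0.5, 0.25]).expand(C, 1, 3)  # (C, 1, 3)
smoothed = F.conv1d(
    ctc_probs.detach().permute(0, 2, 1), weight=kernel, groups=C
).permute(0, 2, 1)  # (2 * N, T - 2, C)
sr_loss = F.kl_div(
    input=ctc_log_probs[:, 1:-1], target=smoothed.log(),
    reduction="none", log_target=True,
).sum()
print(cr_loss.item(), sr_loss.item())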
def forward_transducer(
self,
@@ -296,7 +351,13 @@ class AsrModel(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
use_cr_ctc: bool = False,
use_sr_ctc: bool = False,
use_spec_aug: bool = False,
spec_augment: Optional[SpecAugment] = None,
supervision_segments: Optional[torch.Tensor] = None,
time_warp_factor: Optional[int] = 80,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
@@ -316,9 +377,28 @@ class AsrModel(nn.Module):
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
use_cr_ctc:
Whether to use consistency-regularized CTC.
use_sr_ctc:
Whether to use smooth-regularized CTC.
use_spec_aug:
Whether to apply SpecAugment manually; used only if use_cr_ctc is True.
spec_augment:
The SpecAugment instance that returns time masks,
used only if use_cr_ctc is True.
supervision_segments:
An int tensor of shape ``(S, 3)``. ``S`` is the number of
supervision segments that exist in ``features``.
Used only if use_cr_ctc is True.
time_warp_factor:
Parameter for the time warping; larger values mean more warping.
Set to ``None``, or less than ``1``, to disable.
Used only if use_cr_ctc is True.
Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
Return the transducer losses, CTC loss, AED loss, consistency-regularization
loss, and smooth-regularization loss, in form of (simple_loss, pruned_loss,
ctc_loss, attention_decoder_loss, cr_loss, sr_loss)
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -334,6 +414,24 @@ class AsrModel(nn.Module):
device = x.device
if use_cr_ctc:
assert self.use_ctc
if use_spec_aug:
assert spec_augment is not None and spec_augment.time_warp_factor < 1
# Apply time warping before duplicating the input
assert supervision_segments is not None
x = time_warp(
x,
time_warp_factor=time_warp_factor,
supervision_segments=supervision_segments,
)
# Independently apply frequency masking and time masking to the two copies
x = spec_augment(x.repeat(2, 1, 1))
else:
x = x.repeat(2, 1, 1)
x_lens = x_lens.repeat(2)
y = k2.ragged.cat([y, y], axis=0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
@@ -351,6 +449,9 @@ class AsrModel(nn.Module):
am_scale=am_scale,
lm_scale=lm_scale,
)
if use_cr_ctc:
simple_loss = simple_loss * 0.5
pruned_loss = pruned_loss * 0.5
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
@@ -358,14 +459,23 @@ class AsrModel(nn.Module):
if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
ctc_loss, cr_loss, sr_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
use_consistency_reg=use_cr_ctc,
use_smooth_reg=use_sr_ctc,
)
if use_cr_ctc:
# We duplicate the batch when use_cr_ctc is True
ctc_loss = ctc_loss * 0.5
cr_loss = cr_loss * 0.5
sr_loss = sr_loss * 0.5
else:
ctc_loss = torch.empty(0)
cr_loss = torch.empty(0)
sr_loss = torch.empty(0)
if self.use_attention_decoder:
attention_decoder_loss = self.attention_decoder.calc_att_loss(
@@ -374,7 +484,9 @@ class AsrModel(nn.Module):
ys=y.to(device),
ys_lens=y_lens.to(device),
)
if use_cr_ctc:
attention_decoder_loss = attention_decoder_loss * 0.5
else:
attention_decoder_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss
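A note on the 0.5 factors above (my reading, not stated in the diff): with use_cr_ctc the batch is duplicated, so every summed loss covers the data twice; halving keeps the magnitudes comparable to a run without CR-CTC, so loss scales and logged values mean the same thing either way.

# Sketch: a summed loss over a duplicated batch doubles, so 0.5 restores scale.
import torch

per_utt_loss = torch.tensor([1.0, 3.0])     # hypothetical per-utterance losses
doubled = per_utt_loss.repeat(2).sum()      # duplicated batch: 8.0
assert doubled * 0.5 == per_utt_loss.sum()  # back to the original 4.0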

train.py

@@ -72,6 +72,7 @@ from attention_decoder import AttentionDecoderModel
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset import SpecAugment
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
@@ -304,6 +305,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="If True, use attention-decoder head.",
)
parser.add_argument(
"--use-cr-ctc",
type=str2bool,
default=False,
help="If True, use consistency-regularized CTC.",
)
parser.add_argument(
"--use-sr-ctc",
type=str2bool,
default=False,
help="If True, use smooth-regularized CTC.",
)
def get_parser():
parser = argparse.ArgumentParser(
@@ -449,6 +464,27 @@ def get_parser():
help="Scale for CTC loss.",
)
parser.add_argument(
"--cr-loss-scale",
type=float,
default=0.2,
help="Scale for consistency-regularization loss.",
)
parser.add_argument(
"--sr-loss-scale",
type=float,
default=0.2,
help="Scale for smooth-regularization loss.",
)
parser.add_argument(
"--time-mask-ratio",
type=float,
default=2.5,
help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
)
parser.add_argument(
"--attention-decoder-loss-scale",
type=float,
@@ -717,6 +753,24 @@ def get_model(params: AttributeDict) -> nn.Module:
return model
def get_spec_augment(params: AttributeDict) -> SpecAugment:
num_frame_masks = int(10 * params.time_mask_ratio)
max_frames_mask_fraction = 0.15 * params.time_mask_ratio
logging.info(
f"num_frame_masks: {num_frame_masks}, "
f"max_frames_mask_fraction: {max_frames_mask_fraction}"
)
spec_augment = SpecAugment(
time_warp_factor=0, # Do time warping in model.py
num_frame_masks=num_frame_masks, # default: 10
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15
)
return spec_augment
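With the default --time-mask-ratio of 2.5, get_spec_augment therefore requests 25 time masks and caps the total time-masked fraction at 0.375, i.e. 2.5x the SpecAugment defaults noted in the comments above. A quick check of the arithmetic:

# Quick check of the scaling in get_spec_augment (values follow from the code).
time_mask_ratio = 2.5
print(int(10 * time_mask_ratio))  # -> 25 (num_frame_masks, default 10)
print(0.15 * time_mask_ratio)     # -> 0.375 (max_frames_mask_fraction, default 0.15)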
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
@@ -839,6 +893,7 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
spec_augment: Optional[SpecAugment] = None,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute loss given the model and its inputs.
@@ -855,8 +910,8 @@
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
spec_augment:
The SpecAugment instance used only when use_cr_ctc is True.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
@@ -874,14 +929,36 @@
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
use_cr_ctc = params.use_cr_ctc
use_sr_ctc = params.use_sr_ctc
use_spec_aug = use_cr_ctc and is_training
if use_spec_aug:
supervision_intervals = batch["supervisions"]
supervision_segments = torch.stack(
[
supervision_intervals["sequence_idx"],
supervision_intervals["start_frame"],
supervision_intervals["num_frames"],
],
dim=1,
) # shape: (S, 3)
else:
supervision_segments = None
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model(
simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
use_cr_ctc=use_cr_ctc,
use_sr_ctc=use_sr_ctc,
use_spec_aug=use_spec_aug,
spec_augment=spec_augment,
supervision_segments=supervision_segments,
time_warp_factor=params.spec_aug_time_warp_factor,
)
loss = 0.0
@@ -904,6 +981,10 @@ def compute_loss(
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
if use_cr_ctc:
loss += params.cr_loss_scale * cr_loss
if use_sr_ctc:
loss += params.sr_loss_scale * sr_loss
if params.use_attention_decoder:
loss += params.attention_decoder_loss_scale * attention_decoder_loss
@@ -922,6 +1003,10 @@ def compute_loss(
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if use_cr_ctc:
info["cr_loss"] = cr_loss.detach().cpu().item()
if use_sr_ctc:
info["sr_loss"] = sr_loss.detach().cpu().item()
if params.use_attention_decoder:
info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
@@ -971,6 +1056,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
spec_augment: Optional[SpecAugment] = None,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -997,6 +1083,8 @@
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
spec_augment:
The SpecAugment instance used only when use_cr_ctc is True.
model_avg:
The stored model averaged from the start of training.
tb_writer:
@@ -1043,6 +1131,7 @@
sp=sp,
batch=batch,
is_training=True,
spec_augment=spec_augment,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -1238,6 +1327,13 @@ def run(rank, world_size, args):
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
if params.use_cr_ctc:
assert params.use_ctc
assert not params.enable_spec_aug # we will do spec_augment in model.py
spec_augment = get_spec_augment(params)
else:
spec_augment = None
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
@@ -1360,6 +1456,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
spec_augment=spec_augment,
)
scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
@@ -1387,6 +1484,7 @@
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
spec_augment=spec_augment,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
@@ -1452,6 +1550,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
spec_augment: Optional[SpecAugment] = None,
):
from lhotse.dataset import find_pessimistic_batches
@@ -1471,6 +1570,7 @@
sp=sp,
batch=batch,
is_training=True,
spec_augment=spec_augment,
)
loss.backward()
optimizer.zero_grad()

icefall/utils.py

@@ -21,6 +21,7 @@ import argparse
import collections
import logging
import os
import random
import re
import subprocess
from collections import defaultdict
@@ -38,6 +39,7 @@ import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
from pypinyin import lazy_pinyin, pinyin
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
from torch.utils.tensorboard import SummaryWriter
@@ -2271,3 +2273,41 @@ def num_tokens(
if 0 in ans:
num_tokens -= 1
return num_tokens
# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
def time_warp(
features: torch.Tensor,
p: float = 0.9,
time_warp_factor: Optional[int] = 80,
supervision_segments: Optional[torch.Tensor] = None,
):
"""Apply time warping on a batch of features
"""
if time_warp_factor is None or time_warp_factor < 1:
return features
assert len(features.shape) == 3, (
"SpecAugment only supports batches of single-channel feature matrices."
)
features = features.clone()
if supervision_segments is None:
# No supervisions - apply time warping to full feature matrices.
for sequence_idx in range(features.size(0)):
if random.random() > p:
# Randomly choose whether this transform is applied
continue
features[sequence_idx] = time_warp_impl(
features[sequence_idx], factor=time_warp_factor
)
else:
# Supervisions provided - we will apply time warping only on the supervised areas.
for sequence_idx, start_frame, num_frames in supervision_segments:
if random.random() > p:
# Randomly choose whether this transform is applied
continue
end_frame = start_frame + num_frames
features[sequence_idx, start_frame:end_frame] = time_warp_impl(
features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
)
return features
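A hypothetical usage of the time_warp helper above (shapes and supervision rows are made up for illustration; in training it is called from model.py, where p keeps its 0.9 default):

# Illustrative usage of time_warp (made-up shapes and supervisions).
# In user code: from icefall.utils import time_warp
import torch

features = torch.randn(2, 1000, 80)  # (N, T, C) batch of fbank-like features
supervision_segments = torch.tensor(
    [
        [0, 0, 1000],  # (sequence_idx, start_frame, num_frames)
        [1, 100, 800],
    ],
    dtype=torch.int32,
)
warped = time_warp(
    features,
    p=1.0,  # always apply, to exercise the code path deterministically
    time_warp_factor=80,
    supervision_segments=supervision_segments,
)
assert warped.shape == features.shape  # warping never changes the shape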