Merge a85592dc784de1ac77ca4dac861b46b5e13b539d into f84270c93528f4b77b99ada9ac0c9f7fb231d6a4

Zengwei Yao 2024-10-17 17:01:20 +08:00 committed by GitHub
commit 0be47510a9
3 changed files with 273 additions and 21 deletions

Changed file 1 of 3: the AsrModel definition (model.py)

@@ -16,15 +16,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos, make_pad_mask, time_warp
+from lhotse.dataset import SpecAugment

 class AsrModel(nn.Module):
@@ -110,9 +112,8 @@ class AsrModel(nn.Module):
         if use_ctc:
             # Modules for CTC head
             self.ctc_output = nn.Sequential(
-                nn.Dropout(p=0.1),
+                nn.Dropout(p=0.1),  # TODO: test removing this
                 nn.Linear(encoder_dim, vocab_size),
-                nn.LogSoftmax(dim=-1),
             )

         self.use_attention_decoder = use_attention_decoder
@@ -158,28 +159,82 @@ class AsrModel(nn.Module):
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute CTC loss.
+        use_consistency_reg: bool = False,
+        use_smooth_reg: bool = False,
+        smooth_kernel: List[float] = [0.25, 0.5, 0.25],
+        eps: float = 1e-6,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute CTC loss with optional consistency-regularization and smooth-regularization losses.

         Args:
           encoder_out:
-            Encoder output, of shape (N, T, C).
+            Encoder output, of shape (N or 2 * N, T, C).
           encoder_out_lens:
-            Encoder output lengths, of shape (N,).
+            Encoder output lengths, of shape (N or 2 * N,).
           targets:
-            Target Tensor of shape (sum(target_lengths)). The targets are assumed
+            Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
+          use_consistency_reg:
+            Whether to use consistency regularization.
+          use_smooth_reg:
+            Whether to use smooth regularization.
         """
-        # Compute CTC log-prob
-        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
+        ctc_output = self.ctc_output(encoder_out)  # (N or 2 * N, T, C)
+        length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
+
+        if not use_smooth_reg:
+            ctc_log_probs = F.log_softmax(ctc_output, dim=-1)
+        else:
+            ctc_probs = ctc_output.softmax(dim=-1)  # Used in sr_loss
+            ctc_log_probs = (ctc_probs + eps).log()

+        # Compute CTC loss
         ctc_loss = torch.nn.functional.ctc_loss(
-            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
+            log_probs=ctc_log_probs.permute(1, 0, 2),  # (T, N or 2 * N, C)
             targets=targets.cpu(),
             input_lengths=encoder_out_lens.cpu(),
             target_lengths=target_lengths.cpu(),
             reduction="sum",
         )
-        return ctc_loss
+        if use_consistency_reg:
+            assert ctc_log_probs.shape[0] % 2 == 0
+            # Compute cr_loss
+            exchanged_targets = ctc_log_probs.detach().chunk(2, dim=0)
+            exchanged_targets = torch.cat(
+                [exchanged_targets[1], exchanged_targets[0]], dim=0
+            )  # exchange: [x1, x2] -> [x2, x1]
+            cr_loss = nn.functional.kl_div(
+                input=ctc_log_probs,
+                target=exchanged_targets,
+                reduction="none",
+                log_target=True,
+            )  # (2 * N, T, C)
+            cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
+        else:
+            cr_loss = torch.empty(0)
+
+        if use_smooth_reg:
+            # Hard code the kernel here, could try other values
+            assert len(smooth_kernel) == 3 and sum(smooth_kernel) == 1.0, smooth_kernel
+            smooth_kernel = torch.tensor(smooth_kernel, dtype=ctc_probs.dtype,
+                                         device=ctc_probs.device, requires_grad=False)
+            smooth_kernel = smooth_kernel.unsqueeze(0).unsqueeze(1).expand(ctc_probs.shape[-1], 1, 3)
+            # Now kernel: (C, 1, 3)
+            smoothed_ctc_probs = F.conv1d(
+                ctc_probs.detach().permute(0, 2, 1),  # (N or 2 * N, C, T)
+                weight=smooth_kernel, stride=1, padding=0, groups=ctc_probs.shape[-1]
+            ).permute(0, 2, 1)  # (N or 2 * N, T - 2, C)
+            sr_loss = nn.functional.kl_div(
+                input=ctc_log_probs[:, 1:-1],
+                target=(smoothed_ctc_probs + eps).log(),
+                reduction="none",
+                log_target=True,
+            )  # (N or 2 * N, T - 2, C)
+            sr_loss = sr_loss.masked_fill(length_mask[:, 1:-1], 0.0).sum()
+        else:
+            sr_loss = torch.empty(0)
+
+        return ctc_loss, cr_loss, sr_loss

     def forward_transducer(
         self,
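
The core of the forward_ctc change above is the consistency-regularization term: the batch is duplicated, each copy sees a different SpecAugment, and each copy's CTC posteriors are pulled towards the detached posteriors of the other copy. The following standalone sketch, which is illustrative and not part of the commit, reproduces just that computation on random tensors; the manually built length_mask stands in for make_pad_mask.

# Illustrative sketch (not part of the commit): the CR-CTC consistency term
# on a duplicated batch of size 2 * N, with random log-probs.
import torch
import torch.nn.functional as F

N, T, C = 2, 10, 5
ctc_log_probs = F.log_softmax(torch.randn(2 * N, T, C), dim=-1)
encoder_out_lens = torch.tensor([10, 8, 10, 8])
# True at padded frames, analogous to make_pad_mask(encoder_out_lens)
length_mask = (torch.arange(T)[None, :] >= encoder_out_lens[:, None]).unsqueeze(-1)

# Exchange the two augmented copies so each branch is trained towards the other.
halves = ctc_log_probs.detach().chunk(2, dim=0)
exchanged_targets = torch.cat([halves[1], halves[0]], dim=0)  # [x1, x2] -> [x2, x1]

cr_loss = F.kl_div(
    input=ctc_log_probs,
    target=exchanged_targets,
    reduction="none",
    log_target=True,
)  # (2 * N, T, C)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
print(cr_loss.item())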
@@ -296,7 +351,13 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        use_cr_ctc: bool = False,
+        use_sr_ctc: bool = False,
+        use_spec_aug: bool = False,
+        spec_augment: Optional[SpecAugment] = None,
+        supervision_segments: Optional[torch.Tensor] = None,
+        time_warp_factor: Optional[int] = 80,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
           x:
@@ -316,9 +377,28 @@
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          use_cr_ctc:
+            Whether to use consistency-regularized CTC.
+          use_sr_ctc:
+            Whether to use smooth-regularized CTC.
+          use_spec_aug:
+            Whether to apply SpecAugment manually, used only if use_cr_ctc is True.
+          spec_augment:
+            The SpecAugment instance that returns time masks,
+            used only if use_cr_ctc is True.
+          supervision_segments:
+            An int tensor of shape ``(S, 3)``. ``S`` is the number of
+            supervision segments that exist in ``features``.
+            Used only if use_cr_ctc is True.
+          time_warp_factor:
+            Parameter for the time warping; larger values mean more warping.
+            Set to ``None``, or less than ``1``, to disable.
+            Used only if use_cr_ctc is True.

         Returns:
-          Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
+          Return the transducer losses, CTC loss, AED loss, consistency-regularization
+          loss, and smooth-regularization loss, in form of
+          (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss)

         Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -334,6 +414,24 @@ class AsrModel(nn.Module):
         device = x.device

+        if use_cr_ctc:
+            assert self.use_ctc
+            if use_spec_aug:
+                assert spec_augment is not None and spec_augment.time_warp_factor < 1
+                # Apply time warping before input duplicating
+                assert supervision_segments is not None
+                x = time_warp(
+                    x,
+                    time_warp_factor=time_warp_factor,
+                    supervision_segments=supervision_segments,
+                )
+                # Independently apply frequency masking and time masking to the two copies
+                x = spec_augment(x.repeat(2, 1, 1))
+            else:
+                x = x.repeat(2, 1, 1)
+            x_lens = x_lens.repeat(2)
+            y = k2.ragged.cat([y, y], axis=0)
+
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
@@ -351,6 +449,9 @@
                 am_scale=am_scale,
                 lm_scale=lm_scale,
             )
+            if use_cr_ctc:
+                simple_loss = simple_loss * 0.5
+                pruned_loss = pruned_loss * 0.5
         else:
             simple_loss = torch.empty(0)
             pruned_loss = torch.empty(0)
@@ -358,14 +459,23 @@
         if self.use_ctc:
             # Compute CTC loss
             targets = y.values
-            ctc_loss = self.forward_ctc(
+            ctc_loss, cr_loss, sr_loss = self.forward_ctc(
                 encoder_out=encoder_out,
                 encoder_out_lens=encoder_out_lens,
                 targets=targets,
                 target_lengths=y_lens,
+                use_consistency_reg=use_cr_ctc,
+                use_smooth_reg=use_sr_ctc,
             )
+            if use_cr_ctc:
+                # We duplicate the batch when use_cr_ctc is True
+                ctc_loss = ctc_loss * 0.5
+                cr_loss = cr_loss * 0.5
+                sr_loss = sr_loss * 0.5
         else:
             ctc_loss = torch.empty(0)
+            cr_loss = torch.empty(0)
+            sr_loss = torch.empty(0)

         if self.use_attention_decoder:
             attention_decoder_loss = self.attention_decoder.calc_att_loss(
@@ -374,7 +484,9 @@ class AsrModel(nn.Module):
                 ys=y.to(device),
                 ys_lens=y_lens.to(device),
             )
+            if use_cr_ctc:
+                attention_decoder_loss = attention_decoder_loss * 0.5
         else:
             attention_decoder_loss = torch.empty(0)

-        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
+        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss
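
The smooth-regularization term added in forward_ctc can be isolated in the same way. The sketch below, again illustrative and not part of the commit, shows how the (0.25, 0.5, 0.25) kernel is applied as a depthwise conv1d over time (groups equal to the number of classes) and how the smoothed target aligns with frames 1..T-2 of the log-probs; shapes follow the comments in the diff, values are random.

# Illustrative sketch (not part of the commit): the smooth-regularization target,
# i.e. CTC probabilities smoothed along time with a depthwise 1-D kernel.
import torch
import torch.nn.functional as F

N, T, C = 2, 10, 5
ctc_probs = torch.randn(N, T, C).softmax(dim=-1)
eps = 1e-6

kernel = torch.tensor([0.25, 0.5, 0.25], dtype=ctc_probs.dtype)
kernel = kernel.unsqueeze(0).unsqueeze(1).expand(C, 1, 3).contiguous()  # (C, 1, 3)

smoothed = F.conv1d(
    ctc_probs.detach().permute(0, 2, 1),  # (N, C, T)
    weight=kernel,
    stride=1,
    padding=0,
    groups=C,  # depthwise: each class is smoothed independently
).permute(0, 2, 1)  # (N, T - 2, C); the two boundary frames are dropped

ctc_log_probs = (ctc_probs + eps).log()
sr_loss = F.kl_div(
    input=ctc_log_probs[:, 1:-1],       # align with the T - 2 smoothed frames
    target=(smoothed + eps).log(),
    reduction="none",
    log_target=True,
).sum()
print(sr_loss.item())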

Changed file 2 of 3: the training script

@@ -72,6 +72,7 @@ from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset import SpecAugment
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -304,6 +305,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use attention-decoder head.",
     )

+    parser.add_argument(
+        "--use-cr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use consistency-regularized CTC.",
+    )
+
+    parser.add_argument(
+        "--use-sr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use smooth-regularized CTC.",
+    )

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -449,6 +464,27 @@ def get_parser():
         help="Scale for CTC loss.",
     )

+    parser.add_argument(
+        "--cr-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for consistency-regularization loss.",
+    )
+
+    parser.add_argument(
+        "--sr-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for smooth-regularization loss.",
+    )
+
+    parser.add_argument(
+        "--time-mask-ratio",
+        type=float,
+        default=2.5,
+        help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
+    )
+
     parser.add_argument(
         "--attention-decoder-loss-scale",
         type=float,
@@ -717,6 +753,24 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model

+def get_spec_augment(params: AttributeDict) -> SpecAugment:
+    num_frame_masks = int(10 * params.time_mask_ratio)
+    max_frames_mask_fraction = 0.15 * params.time_mask_ratio
+    logging.info(
+        f"num_frame_masks: {num_frame_masks}, "
+        f"max_frames_mask_fraction: {max_frames_mask_fraction}"
+    )
+    spec_augment = SpecAugment(
+        time_warp_factor=0,  # Do time warping in model.py
+        num_frame_masks=num_frame_masks,  # default: 10
+        features_mask_size=27,
+        num_feature_masks=2,
+        frames_mask_size=100,
+        max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
+    )
+    return spec_augment

 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
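
For reference, with the default --time-mask-ratio of 2.5, get_spec_augment() above configures noticeably heavier time masking than the usual SpecAugment defaults. A quick check of the arithmetic (not part of the commit):

# num_frame_masks and max_frames_mask_fraction with the default --time-mask-ratio
time_mask_ratio = 2.5
num_frame_masks = int(10 * time_mask_ratio)          # 25 (SpecAugment default: 10)
max_frames_mask_fraction = 0.15 * time_mask_ratio    # 0.375 (SpecAugment default: 0.15)
print(num_frame_masks, max_frames_mask_fraction)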
@@ -839,6 +893,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    spec_augment: Optional[SpecAugment] = None,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -855,8 +910,8 @@
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
+      spec_augment:
+        The SpecAugment instance used only when use_cr_ctc is True.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
@@ -874,14 +929,36 @@
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)

+    use_cr_ctc = params.use_cr_ctc
+    use_sr_ctc = params.use_sr_ctc
+    use_spec_aug = use_cr_ctc and is_training
+    if use_spec_aug:
+        supervision_intervals = batch["supervisions"]
+        supervision_segments = torch.stack(
+            [
+                supervision_intervals["sequence_idx"],
+                supervision_intervals["start_frame"],
+                supervision_intervals["num_frames"],
+            ],
+            dim=1,
+        )  # shape: (S, 3)
+    else:
+        supervision_segments = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model(
+        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            use_cr_ctc=use_cr_ctc,
+            use_sr_ctc=use_sr_ctc,
+            use_spec_aug=use_spec_aug,
+            spec_augment=spec_augment,
+            supervision_segments=supervision_segments,
+            time_warp_factor=params.spec_aug_time_warp_factor,
         )

         loss = 0.0
@@ -904,6 +981,10 @@
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
+            if use_cr_ctc:
+                loss += params.cr_loss_scale * cr_loss
+            if use_sr_ctc:
+                loss += params.sr_loss_scale * sr_loss

         if params.use_attention_decoder:
             loss += params.attention_decoder_loss_scale * attention_decoder_loss
@@ -922,6 +1003,10 @@
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+        if use_cr_ctc:
+            info["cr_loss"] = cr_loss.detach().cpu().item()
+        if use_sr_ctc:
+            info["sr_loss"] = sr_loss.detach().cpu().item()
     if params.use_attention_decoder:
         info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
@@ -971,6 +1056,7 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
+    spec_augment: Optional[SpecAugment] = None,
     model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -997,6 +1083,8 @@
         Dataloader for the validation dataset.
       scaler:
         The scaler used for mix precision training.
+      spec_augment:
+        The SpecAugment instance used only when use_cr_ctc is True.
       model_avg:
         The stored model averaged from the start of training.
       tb_writer:
@@ -1043,6 +1131,7 @@
                     sp=sp,
                     batch=batch,
                     is_training=True,
+                    spec_augment=spec_augment,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -1238,6 +1327,13 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

+    if params.use_cr_ctc:
+        assert params.use_ctc
+        assert not params.enable_spec_aug  # we will do spec_augment in model.py
+        spec_augment = get_spec_augment(params)
+    else:
+        spec_augment = None
+
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
     if rank == 0:
@@ -1360,6 +1456,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            spec_augment=spec_augment,
         )

     scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
@@ -1387,6 +1484,7 @@
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
+            spec_augment=spec_augment,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -1452,6 +1550,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    spec_augment: Optional[SpecAugment] = None,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -1471,6 +1570,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
+                    spec_augment=spec_augment,
                 )
             loss.backward()
             optimizer.zero_grad()
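
Putting the training-script changes together: compute_loss() now receives cr_loss and sr_loss from the model and adds them to the total loss with their own scales. A minimal sketch of that combination, with made-up loss values and a hypothetical CTC scale (only the cr/sr defaults of 0.2 come from this diff):

# Illustrative sketch (not part of the commit): how the new regularization terms
# enter the total loss in compute_loss().
import torch

ctc_loss = torch.tensor(120.0)   # made-up value
cr_loss = torch.tensor(30.0)     # made-up value
sr_loss = torch.tensor(18.0)     # made-up value

ctc_loss_scale = 0.2  # hypothetical value for --ctc-loss-scale
cr_loss_scale = 0.2   # default of --cr-loss-scale in this diff
sr_loss_scale = 0.2   # default of --sr-loss-scale in this diff
use_cr_ctc, use_sr_ctc = True, True

loss = ctc_loss_scale * ctc_loss
if use_cr_ctc:
    loss = loss + cr_loss_scale * cr_loss
if use_sr_ctc:
    loss = loss + sr_loss_scale * sr_loss
print(loss.item())  # 0.2 * 120 + 0.2 * 30 + 0.2 * 18 = 33.6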

Changed file 3 of 3: icefall/utils.py

@@ -21,6 +21,7 @@ import argparse
 import collections
 import logging
 import os
+import random
 import re
 import subprocess
 from collections import defaultdict
@@ -38,6 +39,7 @@ import sentencepiece as spm
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
 from pypinyin import lazy_pinyin, pinyin
 from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
 from torch.utils.tensorboard import SummaryWriter
@@ -2271,3 +2273,41 @@ def num_tokens(
     if 0 in ans:
         num_tokens -= 1
     return num_tokens

+
+# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
+def time_warp(
+    features: torch.Tensor,
+    p: float = 0.9,
+    time_warp_factor: Optional[int] = 80,
+    supervision_segments: Optional[torch.Tensor] = None,
+):
+    """Apply time warping on a batch of features."""
+    if time_warp_factor is None or time_warp_factor < 1:
+        return features
+    assert len(features.shape) == 3, (
+        "SpecAugment only supports batches of single-channel feature matrices."
+    )
+    features = features.clone()
+    if supervision_segments is None:
+        # No supervisions - apply spec augment to full feature matrices.
+        for sequence_idx in range(features.size(0)):
+            if random.random() > p:
+                # Randomly choose whether this transform is applied
+                continue
+            features[sequence_idx] = time_warp_impl(
+                features[sequence_idx], factor=time_warp_factor
+            )
+    else:
+        # Supervisions provided - we will apply time warping only on the supervised areas.
+        for sequence_idx, start_frame, num_frames in supervision_segments:
+            if random.random() > p:
+                # Randomly choose whether this transform is applied
+                continue
+            end_frame = start_frame + num_frames
+            features[sequence_idx, start_frame:end_frame] = time_warp_impl(
+                features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
+            )
+    return features
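
A small usage sketch of the new helper (not part of the commit), assuming a checkout where this change is applied so that the function is importable from icefall.utils; the feature shapes and supervision segments below are made up:

# Illustrative usage of the time_warp helper defined above.
import torch
from icefall.utils import time_warp

features = torch.randn(4, 500, 80)  # (N, T, C) batch of fbank features
supervision_segments = torch.tensor(
    [[0, 0, 500], [1, 0, 400], [2, 10, 450], [3, 0, 300]],
    dtype=torch.int64,
)  # columns: (sequence_idx, start_frame, num_frames)

warped = time_warp(
    features,
    p=0.9,                   # per-segment probability of applying the warp
    time_warp_factor=80,
    supervision_segments=supervision_segments,
)
assert warped.shape == features.shape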