From ebbbcbcaf165ffd95e1d7df4fc873e356e8aacd9 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 4 Sep 2024 14:27:25 +0800 Subject: [PATCH] support consistency-regularized CTC --- egs/librispeech/ASR/zipformer/model.py | 171 +++++++++- egs/librispeech/ASR/zipformer/spec_augment.py | 313 ++++++++++++++++++ egs/librispeech/ASR/zipformer/train.py | 85 ++++- 3 files changed, 556 insertions(+), 13 deletions(-) create mode 100644 egs/librispeech/ASR/zipformer/spec_augment.py diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index bd1ed26d8..cf935d835 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -25,6 +25,7 @@ from encoder_interface import EncoderInterface from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask +from spec_augment import SpecAugment, time_warp class AsrModel(nn.Module): @@ -181,6 +182,63 @@ class AsrModel(nn.Module): ) return ctc_loss + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + time_mask: Optional[torch.Tensor] = None, + cr_loss_masked_scale: float = 3.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss with consistency regularization loss. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + time_mask: + Downsampled time masks of shape (2 * N, T, 1). + cr_loss_masked_scale: + The loss scale used to scale up the cr_loss at masked positions. + """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.cpu(), + input_lengths=encoder_out_lens.cpu(), + target_lengths=target_lengths.cpu(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + if time_mask is not None: + assert time_mask.shape[:-1] == ctc_output.shape[:-1], ( + time_mask.shape, ctc_output.shape + ) + masked_scale = time_mask * (cr_loss_masked_scale - 1) + 1 + # e.g., if cr_loss_masked_scale = 3, scales at masked positions are 3, + # scales at unmasked positions are 1 + cr_loss = cr_loss * masked_scale # scaling up masked positions + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + def forward_transducer( self, encoder_out: torch.Tensor, @@ -296,7 +354,13 @@ class AsrModel(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + use_cr_ctc: bool = False, + use_spec_aug: bool = False, + spec_augment: Optional[SpecAugment] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + cr_loss_masked_scale: float = 3.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -316,9 +380,28 @@ class AsrModel(nn.Module): lm_scale: The scale to smooth the loss with lm (output of predictor network) part + use_cr_ctc: + Whether use consistency-regularized CTC. + use_spec_aug: + Whether apply spec-augment manually, used only if use_cr_ctc is True. + spec_augment: + The SpecAugment instance that returns time masks, + used only if use_cr_ctc is True. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. + Used only if use_cr_ctc is True. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if use_cr_ctc is True. + cr_loss_masked_scale: + The loss scale used to scale up the cr_loss at masked positions. + Returns: - Return the transducer losses and CTC loss, - in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss) + Return the transducer losses, CTC loss, AED loss, + and consistency-regularization loss in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) Note: Regarding am_scale & lm_scale, it will make the loss-function one of @@ -334,6 +417,27 @@ class AsrModel(nn.Module): device = x.device + if use_cr_ctc: + assert self.use_ctc + if use_spec_aug: + assert spec_augment is not None and spec_augment.time_warp_factor < 1 + # Apply time warping before input duplicating + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments, + ) + # Independently apply frequency masking and time masking to the two copies + x, time_mask = spec_augment(x.repeat(2, 1, 1)) + # time_mask: 1 for masked, 0 for unmasked + time_mask = downsample_time_mask(time_mask, x.dtype) + else: + x = x.repeat(2, 1, 1) + time_mask = None + x_lens = x_lens.repeat(2) + y = k2.ragged.cat([y, y], axis=0) + # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) @@ -351,6 +455,9 @@ class AsrModel(nn.Module): am_scale=am_scale, lm_scale=lm_scale, ) + if use_cr_ctc: + simple_loss = simple_loss * 0.5 + pruned_loss = pruned_loss * 0.5 else: simple_loss = torch.empty(0) pruned_loss = torch.empty(0) @@ -358,14 +465,28 @@ class AsrModel(nn.Module): if self.use_ctc: # Compute CTC loss targets = y.values - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) + if not use_cr_ctc: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + time_mask=time_mask, + cr_loss_masked_scale=cr_loss_masked_scale, + ) + ctc_loss = ctc_loss * 0.5 + cr_loss = cr_loss * 0.5 else: ctc_loss = torch.empty(0) + cr_loss = torch.empty(0) if self.use_attention_decoder: attention_decoder_loss = self.attention_decoder.calc_att_loss( @@ -374,7 +495,37 @@ class AsrModel(nn.Module): ys=y.to(device), ys_lens=y_lens.to(device), ) + if use_cr_ctc: + attention_decoder_loss = attention_decoder_loss * 0.5 else: attention_decoder_loss = torch.empty(0) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss + + +def downsample_time_mask(time_mask: torch.Tensor, dtype: torch.dtype): + """Downsample the time masks as in Zipformer. + Args: + time_mask: shape of (N, T) + Returns: + The downsampled time masks of shape (N, T', 1), + where T' = ((T - 7) // 2 + 1) // 2 + """ + # Downsample the time masks as in Zipformer + time_mask = time_mask.to(dtype).unsqueeze(dim=1) + # as in conv-embed + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=3, stride=1, padding=0 + ) # T - 2 + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=3, stride=2, padding=0 + ) # (T - 3) // 2 + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=3, stride=1, padding=0 + ) # (T - 7) // 2 + # as in output-downsampling + time_mask = nn.functional.max_pool1d( + time_mask, kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + time_mask = time_mask.transpose(1, 2) # (N * 2, T', 1) + return time_mask diff --git a/egs/librispeech/ASR/zipformer/spec_augment.py b/egs/librispeech/ASR/zipformer/spec_augment.py new file mode 100644 index 000000000..6ddf2b09b --- /dev/null +++ b/egs/librispeech/ASR/zipformer/spec_augment.py @@ -0,0 +1,313 @@ +# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copied from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py +# with minor modification for cr-ctc training. + + +import math +import random +from typing import Any, Dict, Optional, Tuple + +import torch +from lhotse.dataset.signal_transforms import time_warp as time_warp_impl + + +class SpecAugment(torch.nn.Module): + """SpecAugment from lhotse with minor modification, returning time masks. + + SpecAugment performs three augmentations: + - time warping of the feature matrix + - masking of ranges of features (frequency bands) + - masking of ranges of frames (time) + + The current implementation works with batches, but processes each example separately + in a loop rather than simultaneously to achieve different augmentation parameters for + each example. + """ + + def __init__( + self, + time_warp_factor: Optional[int] = 80, + num_feature_masks: int = 2, + features_mask_size: int = 27, + num_frame_masks: int = 10, + frames_mask_size: int = 100, + max_frames_mask_fraction: float = 0.15, + p=0.9, + ): + """ + SpecAugment's constructor. + + :param time_warp_factor: parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable. + :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins). + This is the ``F`` parameter from the SpecAugment paper. + :param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable. + :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames). + This is the ``T`` parameter from the SpecAugment paper. + :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length + of the utterance (or supervision segment). + This is the parameter denoted by ``p`` in the SpecAugment paper. + :param p: the probability of applying this transform. + It is different from ``p`` in the SpecAugment paper! + """ + super().__init__() + assert 0 <= p <= 1 + assert num_feature_masks >= 0 + assert num_frame_masks >= 0 + assert features_mask_size > 0 + assert frames_mask_size > 0 + self.time_warp_factor = time_warp_factor + self.num_feature_masks = num_feature_masks + self.features_mask_size = features_mask_size + self.num_frame_masks = num_frame_masks + self.frames_mask_size = frames_mask_size + self.max_frames_mask_fraction = max_frames_mask_fraction + self.p = p + + def forward( + self, + features: torch.Tensor, + supervision_segments: Optional[torch.IntTensor] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes SpecAugment for a batch of feature matrices. + + Since the batch will usually already be padded, the user can optionally + provide a ``supervision_segments`` tensor that will be used to apply SpecAugment + only to selected areas of the input. The format of this input is described below. + + :param features: a batch of feature matrices with shape ``(B, T, F)``. + :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features`` -- there may be either + less or more than the batch size. + The second dimension encoder three kinds of information: + the sequence index of the corresponding feature matrix in `features`, + the start frame index, and the number of frames for each segment. + :return: + - an augmented tensor of shape ``(B, T, F)``. + - the corresponding time masks of shape ``(B, T)``. + """ + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of " "single-channel feature matrices." + ) + features = features.clone() + + time_masks = [] + + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + masked_feature, time_mask = self._forward_single(features[sequence_idx]) + features[sequence_idx] = masked_feature + time_masks.append(time_mask) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + end_frame = start_frame + num_frames + warped_feature, _ = self._forward_single( + features[sequence_idx, start_frame:end_frame], warp=True, mask=False + ) + features[sequence_idx, start_frame:end_frame] = warped_feature + # ... and then time-mask the full feature matrices. Note that in this mode, + # it might happen that masks are applied to different sequences/examples + # than the time warping. + for sequence_idx in range(features.size(0)): + masked_feature, time_mask = self._forward_single( + features[sequence_idx], warp=False, mask=True + ) + features[sequence_idx] = masked_feature + time_masks.append(time_mask) + + time_masks = torch.cat(time_masks, dim=0) + assert time_masks.shape == features.shape[:-1], (time_masks.shape == features.shape[:-1]) + return features, time_masks + + def _forward_single( + self, features: torch.Tensor, warp: bool = True, mask: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply SpecAugment to a single feature matrix of shape (T, F). + """ + if random.random() > self.p: + # Randomly choose whether this transform is applied + time_mask = torch.zeros( + 1, features.size(0), dtype=torch.bool, device=features.device + ) + return features, time_mask + + time_mask = None + if warp: + if self.time_warp_factor is not None and self.time_warp_factor >= 1: + features = time_warp_impl(features, factor=self.time_warp_factor) + + if mask: + mean = features.mean() + # Frequency masking + features, _ = mask_along_axis_optimized( + features, + mask_size=self.features_mask_size, + mask_times=self.num_feature_masks, + mask_value=mean, + axis=2, + ) + # Time masking + max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + num_frame_masks = min( + self.num_frame_masks, + math.ceil(max_tot_mask_frames / self.frames_mask_size), + ) + max_mask_frames = min( + self.frames_mask_size, max_tot_mask_frames // num_frame_masks + ) + features, time_mask = mask_along_axis_optimized( + features, + mask_size=max_mask_frames, + mask_times=num_frame_masks, + mask_value=mean, + axis=1, + return_time_mask=True, + ) + + return features, time_mask + + def state_dict(self, **kwargs) -> Dict[str, Any]: + return dict( + time_warp_factor=self.time_warp_factor, + num_feature_masks=self.num_feature_masks, + features_mask_size=self.features_mask_size, + num_frame_masks=self.num_frame_masks, + frames_mask_size=self.frames_mask_size, + max_frames_mask_fraction=self.max_frames_mask_fraction, + p=self.p, + ) + + def load_state_dict(self, state_dict: Dict[str, Any]): + self.time_warp_factor = state_dict.get( + "time_warp_factor", self.time_warp_factor + ) + self.num_feature_masks = state_dict.get( + "num_feature_masks", self.num_feature_masks + ) + self.features_mask_size = state_dict.get( + "features_mask_size", self.features_mask_size + ) + self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) + self.frames_mask_size = state_dict.get( + "frames_mask_size", self.frames_mask_size + ) + self.max_frames_mask_fraction = state_dict.get( + "max_frames_mask_fraction", self.max_frames_mask_fraction + ) + self.p = state_dict.get("p", self.p) + + +def mask_along_axis_optimized( + features: torch.Tensor, + mask_size: int, + mask_times: int, + mask_value: float, + axis: int, + return_time_mask: bool = False, +) -> torch.Tensor: + """ + Apply Frequency and Time masking along axis. + Frequency and Time masking as described in the SpecAugment paper. + + :param features: input tensor of shape ``(T, F)`` + :mask_size: the width size for masking. + :mask_times: the number of masking regions. + :mask_value: Value to assign to the masked regions. + :axis: Axis to apply masking on (1 -> time, 2 -> frequency) + :return_time_mask: Whether return the time mask of shape ``(1, T)`` + """ + if axis not in [1, 2]: + raise ValueError("Only Frequency and Time masking are supported!") + + if return_time_mask and axis == 1: + time_mask = torch.zeros( + 1, features.size(0), dtype=torch.bool, device=features.device + ) + else: + time_mask = None + + features = features.unsqueeze(0) + features = features.reshape([-1] + list(features.size()[-2:])) + + values = torch.randint(int(0), int(mask_size), (1, mask_times)) + min_values = torch.rand(1, mask_times) * (features.size(axis) - values) + mask_starts = (min_values.long()).squeeze() + mask_ends = (min_values.long() + values.long()).squeeze() + + if axis == 1: + if mask_times == 1: + features[:, mask_starts:mask_ends] = mask_value + if return_time_mask: + time_mask[:, mask_starts:mask_ends] = True + return features.squeeze(0), time_mask + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, mask_start:mask_end] = mask_value + if return_time_mask: + time_mask[:, mask_start:mask_end] = True + else: + if mask_times == 1: + features[:, :, mask_starts:mask_ends] = mask_value + return features.squeeze(0), time_mask + for (mask_start, mask_end) in zip(mask_starts, mask_ends): + features[:, :, mask_start:mask_end] = mask_value + + features = features.squeeze(0) + return features, time_mask + + +def time_warp( + features: torch.Tensor, + p: float = 0.9, + time_warp_factor: Optional[int] = 80, + supervision_segments: Optional[torch.Tensor] = None, +): + if time_warp_factor is None or time_warp_factor < 1: + return features + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of single-channel feature matrices." + ) + features = features.clone() + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + if random.random() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx] = time_warp_impl( + features[sequence_idx], factor=time_warp_factor + ) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + if random.random() > p: + # Randomly choose whether this transform is applied + continue + end_frame = start_frame + num_frames + features[sequence_idx, start_frame:end_frame] = time_warp_impl( + features[sequence_idx, start_frame:end_frame], factor=time_warp_factor + ) + + return features diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9c1c7f5a7..328b3cfdd 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -102,6 +102,7 @@ from icefall.utils import ( setup_logger, str2bool, ) +from spec_augment import SpecAugment LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -304,6 +305,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use attention-decoder head.", ) + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -449,6 +457,13 @@ def get_parser(): help="Scale for CTC loss.", ) + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.1, + help="Scale for consistency-regularization loss.", + ) + parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -590,6 +605,11 @@ def get_params() -> AttributeDict: # parameters for attention-decoder "ignore_id": -1, "label_smoothing": 0.1, + # parameters used for CR-CTC + # When using cr-ctc, we increase the time-masking ratio. + "time_mask_ratio": 2.0, + # The scale used to scale up the cr_loss at masked positions. + "cr_loss_masked_scale": 3.0, "warm_step": 2000, "env_info": get_env_info(), } @@ -717,6 +737,24 @@ def get_model(params: AttributeDict) -> nn.Module: return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = 10 * params.time_mask_ratio + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -839,6 +877,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -855,8 +894,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -874,14 +913,35 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, + cr_loss_masked_scale=params.cr_loss_masked_scale, ) loss = 0.0 @@ -904,6 +964,8 @@ def compute_loss( if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -922,6 +984,8 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() @@ -971,6 +1035,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -997,6 +1062,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. tb_writer: @@ -1043,6 +1110,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1238,6 +1306,13 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1360,6 +1435,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) @@ -1387,6 +1463,7 @@ def run(rank, world_size, args): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1452,6 +1529,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1471,6 +1549,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad()