diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index 2de1e08fe..deebb2a75 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -24,8 +24,8 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos, make_pad_mask
-from spec_augment import SpecAugment, time_warp
+from icefall.utils import add_sos, make_pad_mask, time_warp
+from lhotse.dataset import SpecAugment
 
 
 class AsrModel(nn.Module):
@@ -188,8 +188,6 @@ class AsrModel(nn.Module):
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
-        time_mask: Optional[torch.Tensor] = None,
-        cr_loss_masked_scale: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute CTC loss with consistency regularization loss.
         Args:
@@ -200,10 +198,6 @@ class AsrModel(nn.Module):
           targets:
             Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
-          time_mask:
-            Downsampled time masks of shape (2 * N, T, 1).
-          cr_loss_masked_scale:
-            The loss scale used to scale up the cr_loss at masked positions.
         """
         # Compute CTC loss
         ctc_output = self.ctc_output(encoder_out)  # (2 * N, T, C)
@@ -226,14 +220,6 @@ class AsrModel(nn.Module):
             reduction="none",
             log_target=True,
         )  # (2 * N, T, C)
-        if time_mask is not None:
-            assert time_mask.shape[:-1] == ctc_output.shape[:-1], (
-                time_mask.shape, ctc_output.shape
-            )
-            masked_scale = time_mask * (cr_loss_masked_scale - 1) + 1
-            # e.g., if cr_loss_masked_scale = 3, scales at masked positions are 3,
-            # scales at unmasked positions are 1
-            cr_loss = cr_loss * masked_scale  # scaling up masked positions
         length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
         cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
 
@@ -359,7 +345,6 @@ class AsrModel(nn.Module):
         spec_augment: Optional[SpecAugment] = None,
         supervision_segments: Optional[torch.Tensor] = None,
         time_warp_factor: Optional[int] = 80,
-        cr_loss_masked_scale: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -395,8 +380,6 @@ class AsrModel(nn.Module):
             Parameter for the time warping; larger values mean more warping.
             Set to ``None``, or less than ``1``, to disable.
             Used only if use_cr_ctc is True.
-          cr_loss_masked_scale:
-            The loss scale used to scale up the cr_loss at masked positions.
 
         Returns:
           Return the transducer losses, CTC loss, AED loss,
@@ -429,12 +412,9 @@ class AsrModel(nn.Module):
                     supervision_segments=supervision_segments,
                 )
                 # Independently apply frequency masking and time masking to the two copies
-                x, time_mask = spec_augment(x.repeat(2, 1, 1))
-                # time_mask: 1 for masked, 0 for unmasked
-                time_mask = downsample_time_mask(time_mask, x.dtype)
+                x = spec_augment(x.repeat(2, 1, 1))
             else:
                 x = x.repeat(2, 1, 1)
-                time_mask = None
             x_lens = x_lens.repeat(2)
             y = k2.ragged.cat([y, y], axis=0)
 
@@ -479,8 +459,6 @@ class AsrModel(nn.Module):
                 encoder_out_lens=encoder_out_lens,
                 targets=targets,
                 target_lengths=y_lens,
-                time_mask=time_mask,
-                cr_loss_masked_scale=cr_loss_masked_scale,
             )
             ctc_loss = ctc_loss * 0.5
             cr_loss = cr_loss * 0.5
@@ -501,31 +479,3 @@ class AsrModel(nn.Module):
             attention_decoder_loss = torch.empty(0)
 
         return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
-
-
-def downsample_time_mask(time_mask: torch.Tensor, dtype: torch.dtype):
-    """Downsample the time masks as in Zipformer.
-    Args:
-      time_mask: shape of (N, T)
-    Returns:
-      The downsampled time masks of shape (N, T', 1),
-      where T' = ((T - 7) // 2 + 1) // 2
-    """
-    # Downsample the time masks as in Zipformer
-    time_mask = time_mask.to(dtype).unsqueeze(dim=1)
-    # as in conv-embed
-    time_mask = nn.functional.max_pool1d(
-        time_mask, kernel_size=3, stride=1, padding=0
-    )  # T - 2
-    time_mask = nn.functional.max_pool1d(
-        time_mask, kernel_size=3, stride=2, padding=0
-    )  # (T - 3) // 2
-    time_mask = nn.functional.max_pool1d(
-        time_mask, kernel_size=3, stride=1, padding=0
-    )  # (T - 7) // 2
-    # as in output-downsampling
-    time_mask = nn.functional.max_pool1d(
-        time_mask, kernel_size=2, stride=2, padding=0, ceil_mode=True
-    )
-    time_mask = time_mask.transpose(1, 2)  # (N * 2, T', 1)
-    return time_mask
diff --git a/egs/librispeech/ASR/zipformer/spec_augment.py b/egs/librispeech/ASR/zipformer/spec_augment.py
deleted file mode 100644
index 6ddf2b09b..000000000
--- a/egs/librispeech/ASR/zipformer/spec_augment.py
+++ /dev/null
@@ -1,313 +0,0 @@
-# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# Copied from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
-# with minor modification for cr-ctc training.
-
-
-import math
-import random
-from typing import Any, Dict, Optional, Tuple
-
-import torch
-from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
-
-
-class SpecAugment(torch.nn.Module):
-    """SpecAugment from lhotse with minor modification, returning time masks.
-
-    SpecAugment performs three augmentations:
-    - time warping of the feature matrix
-    - masking of ranges of features (frequency bands)
-    - masking of ranges of frames (time)
-
-    The current implementation works with batches, but processes each example separately
-    in a loop rather than simultaneously to achieve different augmentation parameters for
-    each example.
-    """
-
-    def __init__(
-        self,
-        time_warp_factor: Optional[int] = 80,
-        num_feature_masks: int = 2,
-        features_mask_size: int = 27,
-        num_frame_masks: int = 10,
-        frames_mask_size: int = 100,
-        max_frames_mask_fraction: float = 0.15,
-        p=0.9,
-    ):
-        """
-        SpecAugment's constructor.
-
-        :param time_warp_factor: parameter for the time warping; larger values mean more warping.
-            Set to ``None``, or less than ``1``, to disable.
-        :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable.
-        :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
-            This is the ``F`` parameter from the SpecAugment paper.
-        :param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable.
-        :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
-            This is the ``T`` parameter from the SpecAugment paper.
-        :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
-            of the utterance (or supervision segment).
-            This is the parameter denoted by ``p`` in the SpecAugment paper.
-        :param p: the probability of applying this transform.
-            It is different from ``p`` in the SpecAugment paper!
-        """
-        super().__init__()
-        assert 0 <= p <= 1
-        assert num_feature_masks >= 0
-        assert num_frame_masks >= 0
-        assert features_mask_size > 0
-        assert frames_mask_size > 0
-        self.time_warp_factor = time_warp_factor
-        self.num_feature_masks = num_feature_masks
-        self.features_mask_size = features_mask_size
-        self.num_frame_masks = num_frame_masks
-        self.frames_mask_size = frames_mask_size
-        self.max_frames_mask_fraction = max_frames_mask_fraction
-        self.p = p
-
-    def forward(
-        self,
-        features: torch.Tensor,
-        supervision_segments: Optional[torch.IntTensor] = None,
-        *args,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Computes SpecAugment for a batch of feature matrices.
-
-        Since the batch will usually already be padded, the user can optionally
-        provide a ``supervision_segments`` tensor that will be used to apply SpecAugment
-        only to selected areas of the input. The format of this input is described below.
-
-        :param features: a batch of feature matrices with shape ``(B, T, F)``.
-        :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of
-            supervision segments that exist in ``features`` -- there may be either
-            less or more than the batch size.
-            The second dimension encodes three kinds of information:
-            the sequence index of the corresponding feature matrix in `features`,
-            the start frame index, and the number of frames for each segment.
-        :return:
-            - an augmented tensor of shape ``(B, T, F)``.
-            - the corresponding time masks of shape ``(B, T)``.
-        """
-        assert len(features.shape) == 3, (
-            "SpecAugment only supports batches of " "single-channel feature matrices."
-        )
-        features = features.clone()
-
-        time_masks = []
-
-        if supervision_segments is None:
-            # No supervisions - apply spec augment to full feature matrices.
-            for sequence_idx in range(features.size(0)):
-                masked_feature, time_mask = self._forward_single(features[sequence_idx])
-                features[sequence_idx] = masked_feature
-                time_masks.append(time_mask)
-        else:
-            # Supervisions provided - we will apply time warping only on the supervised areas.
-            for sequence_idx, start_frame, num_frames in supervision_segments:
-                end_frame = start_frame + num_frames
-                warped_feature, _ = self._forward_single(
-                    features[sequence_idx, start_frame:end_frame], warp=True, mask=False
-                )
-                features[sequence_idx, start_frame:end_frame] = warped_feature
-            # ... and then time-mask the full feature matrices. Note that in this mode,
-            # it might happen that masks are applied to different sequences/examples
-            # than the time warping.
-            for sequence_idx in range(features.size(0)):
-                masked_feature, time_mask = self._forward_single(
-                    features[sequence_idx], warp=False, mask=True
-                )
-                features[sequence_idx] = masked_feature
-                time_masks.append(time_mask)
-
-        time_masks = torch.cat(time_masks, dim=0)
-        assert time_masks.shape == features.shape[:-1], (time_masks.shape, features.shape)
-        return features, time_masks
-
-    def _forward_single(
-        self, features: torch.Tensor, warp: bool = True, mask: bool = True
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Apply SpecAugment to a single feature matrix of shape (T, F).
- """ - if random.random() > self.p: - # Randomly choose whether this transform is applied - time_mask = torch.zeros( - 1, features.size(0), dtype=torch.bool, device=features.device - ) - return features, time_mask - - time_mask = None - if warp: - if self.time_warp_factor is not None and self.time_warp_factor >= 1: - features = time_warp_impl(features, factor=self.time_warp_factor) - - if mask: - mean = features.mean() - # Frequency masking - features, _ = mask_along_axis_optimized( - features, - mask_size=self.features_mask_size, - mask_times=self.num_feature_masks, - mask_value=mean, - axis=2, - ) - # Time masking - max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) - num_frame_masks = min( - self.num_frame_masks, - math.ceil(max_tot_mask_frames / self.frames_mask_size), - ) - max_mask_frames = min( - self.frames_mask_size, max_tot_mask_frames // num_frame_masks - ) - features, time_mask = mask_along_axis_optimized( - features, - mask_size=max_mask_frames, - mask_times=num_frame_masks, - mask_value=mean, - axis=1, - return_time_mask=True, - ) - - return features, time_mask - - def state_dict(self, **kwargs) -> Dict[str, Any]: - return dict( - time_warp_factor=self.time_warp_factor, - num_feature_masks=self.num_feature_masks, - features_mask_size=self.features_mask_size, - num_frame_masks=self.num_frame_masks, - frames_mask_size=self.frames_mask_size, - max_frames_mask_fraction=self.max_frames_mask_fraction, - p=self.p, - ) - - def load_state_dict(self, state_dict: Dict[str, Any]): - self.time_warp_factor = state_dict.get( - "time_warp_factor", self.time_warp_factor - ) - self.num_feature_masks = state_dict.get( - "num_feature_masks", self.num_feature_masks - ) - self.features_mask_size = state_dict.get( - "features_mask_size", self.features_mask_size - ) - self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) - self.frames_mask_size = state_dict.get( - "frames_mask_size", self.frames_mask_size - ) - self.max_frames_mask_fraction = state_dict.get( - "max_frames_mask_fraction", self.max_frames_mask_fraction - ) - self.p = state_dict.get("p", self.p) - - -def mask_along_axis_optimized( - features: torch.Tensor, - mask_size: int, - mask_times: int, - mask_value: float, - axis: int, - return_time_mask: bool = False, -) -> torch.Tensor: - """ - Apply Frequency and Time masking along axis. - Frequency and Time masking as described in the SpecAugment paper. - - :param features: input tensor of shape ``(T, F)`` - :mask_size: the width size for masking. - :mask_times: the number of masking regions. - :mask_value: Value to assign to the masked regions. 
-    :axis: Axis to apply masking on (1 -> time, 2 -> frequency)
-    :return_time_mask: Whether to return the time mask of shape ``(1, T)``
-    """
-    if axis not in [1, 2]:
-        raise ValueError("Only Frequency and Time masking are supported!")
-
-    if return_time_mask and axis == 1:
-        time_mask = torch.zeros(
-            1, features.size(0), dtype=torch.bool, device=features.device
-        )
-    else:
-        time_mask = None
-
-    features = features.unsqueeze(0)
-    features = features.reshape([-1] + list(features.size()[-2:]))
-
-    values = torch.randint(int(0), int(mask_size), (1, mask_times))
-    min_values = torch.rand(1, mask_times) * (features.size(axis) - values)
-    mask_starts = (min_values.long()).squeeze()
-    mask_ends = (min_values.long() + values.long()).squeeze()
-
-    if axis == 1:
-        if mask_times == 1:
-            features[:, mask_starts:mask_ends] = mask_value
-            if return_time_mask:
-                time_mask[:, mask_starts:mask_ends] = True
-            return features.squeeze(0), time_mask
-        for (mask_start, mask_end) in zip(mask_starts, mask_ends):
-            features[:, mask_start:mask_end] = mask_value
-            if return_time_mask:
-                time_mask[:, mask_start:mask_end] = True
-    else:
-        if mask_times == 1:
-            features[:, :, mask_starts:mask_ends] = mask_value
-            return features.squeeze(0), time_mask
-        for (mask_start, mask_end) in zip(mask_starts, mask_ends):
-            features[:, :, mask_start:mask_end] = mask_value
-
-    features = features.squeeze(0)
-    return features, time_mask
-
-
-def time_warp(
-    features: torch.Tensor,
-    p: float = 0.9,
-    time_warp_factor: Optional[int] = 80,
-    supervision_segments: Optional[torch.Tensor] = None,
-):
-    if time_warp_factor is None or time_warp_factor < 1:
-        return features
-    assert len(features.shape) == 3, (
-        "SpecAugment only supports batches of single-channel feature matrices."
-    )
-    features = features.clone()
-    if supervision_segments is None:
-        # No supervisions - apply spec augment to full feature matrices.
-        for sequence_idx in range(features.size(0)):
-            if random.random() > p:
-                # Randomly choose whether this transform is applied
-                continue
-            features[sequence_idx] = time_warp_impl(
-                features[sequence_idx], factor=time_warp_factor
-            )
-    else:
-        # Supervisions provided - we will apply time warping only on the supervised areas.
-        for sequence_idx, start_frame, num_frames in supervision_segments:
-            if random.random() > p:
-                # Randomly choose whether this transform is applied
-                continue
-            end_frame = start_frame + num_frames
-            features[sequence_idx, start_frame:end_frame] = time_warp_impl(
-                features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
-            )
-
-    return features
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index 3fde55de2..3a8995c81 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -72,6 +72,7 @@ from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset import SpecAugment
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -102,7 +103,6 @@ from icefall.utils import (
     setup_logger,
     str2bool,
 )
-from spec_augment import SpecAugment
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -460,22 +460,15 @@ def get_parser():
     parser.add_argument(
         "--cr-loss-scale",
         type=float,
-        default=0.15,
+        default=0.2,
         help="Scale for consistency-regularization loss.",
     )
 
     parser.add_argument(
         "--time-mask-ratio",
         type=float,
-        default=2.0,
-        help="When using cr-ctc, we increase the time-masking ratio.",
-    )
-
-    parser.add_argument(
-        "--cr-loss-masked-scale",
-        type=float,
-        default=1.0,
-        help="The value used to scale up the cr_loss at masked positions",
+        default=2.5,
+        help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
     )
 
     parser.add_argument(
@@ -950,7 +943,6 @@ def compute_loss(
             spec_augment=spec_augment,
             supervision_segments=supervision_segments,
             time_warp_factor=params.spec_aug_time_warp_factor,
-            cr_loss_masked_scale=params.cr_loss_masked_scale,
         )
 
     loss = 0.0
diff --git a/icefall/utils.py b/icefall/utils.py
index 1dbb954de..b0a42cefa 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -21,6 +21,7 @@ import argparse
 import collections
 import logging
 import os
+import random
 import re
 import subprocess
 from collections import defaultdict
@@ -38,6 +39,7 @@ import sentencepiece as spm
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
 from pypinyin import lazy_pinyin, pinyin
 from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
 from torch.utils.tensorboard import SummaryWriter
@@ -2271,3 +2273,41 @@ def num_tokens(
     if 0 in ans:
         num_tokens -= 1
     return num_tokens
+
+
+# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
+def time_warp(
+    features: torch.Tensor,
+    p: float = 0.9,
+    time_warp_factor: Optional[int] = 80,
+    supervision_segments: Optional[torch.Tensor] = None,
+):
+    """Apply time warping on a batch of features.
+    """
+    if time_warp_factor is None or time_warp_factor < 1:
+        return features
+    assert len(features.shape) == 3, (
+        "SpecAugment only supports batches of single-channel feature matrices."
+    )
+    features = features.clone()
+    if supervision_segments is None:
+        # No supervisions - apply time warping to full feature matrices.
+        for sequence_idx in range(features.size(0)):
+            if random.random() > p:
+                # Randomly choose whether this transform is applied
+                continue
+            features[sequence_idx] = time_warp_impl(
+                features[sequence_idx], factor=time_warp_factor
+            )
+    else:
+        # Supervisions provided - we will apply time warping only on the supervised areas.
+        for sequence_idx, start_frame, num_frames in supervision_segments:
+            if random.random() > p:
+                # Randomly choose whether this transform is applied
+                continue
+            end_frame = start_frame + num_frames
+            features[sequence_idx, start_frame:end_frame] = time_warp_impl(
+                features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
+            )
+
+    return features
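
Usage sketch (not part of the patch): after this refactor, cr-ctc training composes lhotse's stock SpecAugment with the time_warp helper that now lives in icefall.utils, instead of the local spec_augment.py copy that also returned time masks. The snippet below is a minimal illustration of that data path, mirroring the use_cr_ctc branch of model.py's forward() above; the tensor shapes and the SpecAugment parameter values are illustrative assumptions (train.py derives the actual amount of time masking from --time-mask-ratio), not values taken verbatim from this diff.

    import torch
    from lhotse.dataset import SpecAugment

    from icefall.utils import time_warp

    # Hypothetical batch of fbank features of shape (N, T, F) = (4, 500, 80).
    x = torch.randn(4, 500, 80)

    # 1) Time warping is applied once, before the batch is duplicated, so both
    #    copies share the same warp (in train.py, supervision_segments would be
    #    passed so that only the supervised regions are warped).
    x = time_warp(x, time_warp_factor=80, supervision_segments=None)

    # 2) Frequency and time masking come straight from lhotse's SpecAugment and
    #    are applied independently to the two copies. lhotse's forward() returns
    #    only the augmented features; the per-frame time masks (and the removed
    #    cr_loss_masked_scale machinery that consumed them) no longer exist.
    spec_augment = SpecAugment(
        time_warp_factor=0,  # disabled here: warping was already applied above
        num_feature_masks=2,
        features_mask_size=27,
        num_frame_masks=25,  # illustrative: 10 * time_mask_ratio with the new default of 2.5
        frames_mask_size=100,
        max_frames_mask_fraction=0.15,
    )
    x = spec_augment(x.repeat(2, 1, 1))  # (2 * N, T, F)

The two augmented copies then pass through the encoder together, and forward_cr_ctc() computes the consistency-regularization loss as a KL divergence between their exchanged CTC posteriors, now weighted uniformly over all non-padding frames.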