refactor codes

This commit is contained in:
yaozengwei 2024-10-08 00:34:32 +08:00
parent a6eead6c98
commit ae59e5d61e
4 changed files with 47 additions and 378 deletions

View File

@ -24,8 +24,8 @@ import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import ScaledLinear from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask from icefall.utils import add_sos, make_pad_mask, time_warp
from spec_augment import SpecAugment, time_warp from lhotse.dataset import SpecAugment
class AsrModel(nn.Module): class AsrModel(nn.Module):
@ -188,8 +188,6 @@ class AsrModel(nn.Module):
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
targets: torch.Tensor, targets: torch.Tensor,
target_lengths: torch.Tensor, target_lengths: torch.Tensor,
time_mask: Optional[torch.Tensor] = None,
cr_loss_masked_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute CTC loss with consistency regularization loss. """Compute CTC loss with consistency regularization loss.
Args: Args:
@ -200,10 +198,6 @@ class AsrModel(nn.Module):
targets: targets:
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension. to be un-padded and concatenated within 1 dimension.
time_mask:
Downsampled time masks of shape (2 * N, T, 1).
cr_loss_masked_scale:
The loss scale used to scale up the cr_loss at masked positions.
""" """
# Compute CTC loss # Compute CTC loss
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
@ -226,14 +220,6 @@ class AsrModel(nn.Module):
reduction="none", reduction="none",
log_target=True, log_target=True,
) # (2 * N, T, C) ) # (2 * N, T, C)
if time_mask is not None:
assert time_mask.shape[:-1] == ctc_output.shape[:-1], (
time_mask.shape, ctc_output.shape
)
masked_scale = time_mask * (cr_loss_masked_scale - 1) + 1
# e.g., if cr_loss_masked_scale = 3, scales at masked positions are 3,
# scales at unmasked positions are 1
cr_loss = cr_loss * masked_scale # scaling up masked positions
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
@ -359,7 +345,6 @@ class AsrModel(nn.Module):
spec_augment: Optional[SpecAugment] = None, spec_augment: Optional[SpecAugment] = None,
supervision_segments: Optional[torch.Tensor] = None, supervision_segments: Optional[torch.Tensor] = None,
time_warp_factor: Optional[int] = 80, time_warp_factor: Optional[int] = 80,
cr_loss_masked_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Args: Args:
@ -395,8 +380,6 @@ class AsrModel(nn.Module):
Parameter for the time warping; larger values mean more warping. Parameter for the time warping; larger values mean more warping.
Set to ``None``, or less than ``1``, to disable. Set to ``None``, or less than ``1``, to disable.
Used only if use_cr_ctc is True. Used only if use_cr_ctc is True.
cr_loss_masked_scale:
The loss scale used to scale up the cr_loss at masked positions.
Returns: Returns:
Return the transducer losses, CTC loss, AED loss, Return the transducer losses, CTC loss, AED loss,
@ -429,12 +412,9 @@ class AsrModel(nn.Module):
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
) )
# Independently apply frequency masking and time masking to the two copies # Independently apply frequency masking and time masking to the two copies
x, time_mask = spec_augment(x.repeat(2, 1, 1)) x = spec_augment(x.repeat(2, 1, 1))
# time_mask: 1 for masked, 0 for unmasked
time_mask = downsample_time_mask(time_mask, x.dtype)
else: else:
x = x.repeat(2, 1, 1) x = x.repeat(2, 1, 1)
time_mask = None
x_lens = x_lens.repeat(2) x_lens = x_lens.repeat(2)
y = k2.ragged.cat([y, y], axis=0) y = k2.ragged.cat([y, y], axis=0)
@ -479,8 +459,6 @@ class AsrModel(nn.Module):
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
targets=targets, targets=targets,
target_lengths=y_lens, target_lengths=y_lens,
time_mask=time_mask,
cr_loss_masked_scale=cr_loss_masked_scale,
) )
ctc_loss = ctc_loss * 0.5 ctc_loss = ctc_loss * 0.5
cr_loss = cr_loss * 0.5 cr_loss = cr_loss * 0.5
@ -501,31 +479,3 @@ class AsrModel(nn.Module):
attention_decoder_loss = torch.empty(0) attention_decoder_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
def downsample_time_mask(time_mask: torch.Tensor, dtype: torch.dtype):
"""Downsample the time masks as in Zipformer.
Args:
time_mask: shape of (N, T)
Returns:
The downsampled time masks of shape (N, T', 1),
where T' = ((T - 7) // 2 + 1) // 2
"""
# Downsample the time masks as in Zipformer
time_mask = time_mask.to(dtype).unsqueeze(dim=1)
# as in conv-embed
time_mask = nn.functional.max_pool1d(
time_mask, kernel_size=3, stride=1, padding=0
) # T - 2
time_mask = nn.functional.max_pool1d(
time_mask, kernel_size=3, stride=2, padding=0
) # (T - 3) // 2
time_mask = nn.functional.max_pool1d(
time_mask, kernel_size=3, stride=1, padding=0
) # (T - 7) // 2
# as in output-downsampling
time_mask = nn.functional.max_pool1d(
time_mask, kernel_size=2, stride=2, padding=0, ceil_mode=True
)
time_mask = time_mask.transpose(1, 2) # (N * 2, T', 1)
return time_mask

View File

@ -1,313 +0,0 @@
# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copied from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
# with minor modification for cr-ctc training.
import math
import random
from typing import Any, Dict, Optional, Tuple
import torch
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
class SpecAugment(torch.nn.Module):
"""SpecAugment from lhotse with minor modification, returning time masks.
SpecAugment performs three augmentations:
- time warping of the feature matrix
- masking of ranges of features (frequency bands)
- masking of ranges of frames (time)
The current implementation works with batches, but processes each example separately
in a loop rather than simultaneously to achieve different augmentation parameters for
each example.
"""
def __init__(
self,
time_warp_factor: Optional[int] = 80,
num_feature_masks: int = 2,
features_mask_size: int = 27,
num_frame_masks: int = 10,
frames_mask_size: int = 100,
max_frames_mask_fraction: float = 0.15,
p=0.9,
):
"""
SpecAugment's constructor.
:param time_warp_factor: parameter for the time warping; larger values mean more warping.
Set to ``None``, or less than ``1``, to disable.
:param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable.
:param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
This is the ``F`` parameter from the SpecAugment paper.
:param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable.
:param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
This is the ``T`` parameter from the SpecAugment paper.
:param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
of the utterance (or supervision segment).
This is the parameter denoted by ``p`` in the SpecAugment paper.
:param p: the probability of applying this transform.
It is different from ``p`` in the SpecAugment paper!
"""
super().__init__()
assert 0 <= p <= 1
assert num_feature_masks >= 0
assert num_frame_masks >= 0
assert features_mask_size > 0
assert frames_mask_size > 0
self.time_warp_factor = time_warp_factor
self.num_feature_masks = num_feature_masks
self.features_mask_size = features_mask_size
self.num_frame_masks = num_frame_masks
self.frames_mask_size = frames_mask_size
self.max_frames_mask_fraction = max_frames_mask_fraction
self.p = p
def forward(
self,
features: torch.Tensor,
supervision_segments: Optional[torch.IntTensor] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes SpecAugment for a batch of feature matrices.
Since the batch will usually already be padded, the user can optionally
provide a ``supervision_segments`` tensor that will be used to apply SpecAugment
only to selected areas of the input. The format of this input is described below.
:param features: a batch of feature matrices with shape ``(B, T, F)``.
:param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of
supervision segments that exist in ``features`` -- there may be either
less or more than the batch size.
The second dimension encoder three kinds of information:
the sequence index of the corresponding feature matrix in `features`,
the start frame index, and the number of frames for each segment.
:return:
- an augmented tensor of shape ``(B, T, F)``.
- the corresponding time masks of shape ``(B, T)``.
"""
assert len(features.shape) == 3, (
"SpecAugment only supports batches of " "single-channel feature matrices."
)
features = features.clone()
time_masks = []
if supervision_segments is None:
# No supervisions - apply spec augment to full feature matrices.
for sequence_idx in range(features.size(0)):
masked_feature, time_mask = self._forward_single(features[sequence_idx])
features[sequence_idx] = masked_feature
time_masks.append(time_mask)
else:
# Supervisions provided - we will apply time warping only on the supervised areas.
for sequence_idx, start_frame, num_frames in supervision_segments:
end_frame = start_frame + num_frames
warped_feature, _ = self._forward_single(
features[sequence_idx, start_frame:end_frame], warp=True, mask=False
)
features[sequence_idx, start_frame:end_frame] = warped_feature
# ... and then time-mask the full feature matrices. Note that in this mode,
# it might happen that masks are applied to different sequences/examples
# than the time warping.
for sequence_idx in range(features.size(0)):
masked_feature, time_mask = self._forward_single(
features[sequence_idx], warp=False, mask=True
)
features[sequence_idx] = masked_feature
time_masks.append(time_mask)
time_masks = torch.cat(time_masks, dim=0)
assert time_masks.shape == features.shape[:-1], (time_masks.shape == features.shape[:-1])
return features, time_masks
def _forward_single(
self, features: torch.Tensor, warp: bool = True, mask: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply SpecAugment to a single feature matrix of shape (T, F).
"""
if random.random() > self.p:
# Randomly choose whether this transform is applied
time_mask = torch.zeros(
1, features.size(0), dtype=torch.bool, device=features.device
)
return features, time_mask
time_mask = None
if warp:
if self.time_warp_factor is not None and self.time_warp_factor >= 1:
features = time_warp_impl(features, factor=self.time_warp_factor)
if mask:
mean = features.mean()
# Frequency masking
features, _ = mask_along_axis_optimized(
features,
mask_size=self.features_mask_size,
mask_times=self.num_feature_masks,
mask_value=mean,
axis=2,
)
# Time masking
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
num_frame_masks = min(
self.num_frame_masks,
math.ceil(max_tot_mask_frames / self.frames_mask_size),
)
max_mask_frames = min(
self.frames_mask_size, max_tot_mask_frames // num_frame_masks
)
features, time_mask = mask_along_axis_optimized(
features,
mask_size=max_mask_frames,
mask_times=num_frame_masks,
mask_value=mean,
axis=1,
return_time_mask=True,
)
return features, time_mask
def state_dict(self, **kwargs) -> Dict[str, Any]:
return dict(
time_warp_factor=self.time_warp_factor,
num_feature_masks=self.num_feature_masks,
features_mask_size=self.features_mask_size,
num_frame_masks=self.num_frame_masks,
frames_mask_size=self.frames_mask_size,
max_frames_mask_fraction=self.max_frames_mask_fraction,
p=self.p,
)
def load_state_dict(self, state_dict: Dict[str, Any]):
self.time_warp_factor = state_dict.get(
"time_warp_factor", self.time_warp_factor
)
self.num_feature_masks = state_dict.get(
"num_feature_masks", self.num_feature_masks
)
self.features_mask_size = state_dict.get(
"features_mask_size", self.features_mask_size
)
self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks)
self.frames_mask_size = state_dict.get(
"frames_mask_size", self.frames_mask_size
)
self.max_frames_mask_fraction = state_dict.get(
"max_frames_mask_fraction", self.max_frames_mask_fraction
)
self.p = state_dict.get("p", self.p)
def mask_along_axis_optimized(
features: torch.Tensor,
mask_size: int,
mask_times: int,
mask_value: float,
axis: int,
return_time_mask: bool = False,
) -> torch.Tensor:
"""
Apply Frequency and Time masking along axis.
Frequency and Time masking as described in the SpecAugment paper.
:param features: input tensor of shape ``(T, F)``
:mask_size: the width size for masking.
:mask_times: the number of masking regions.
:mask_value: Value to assign to the masked regions.
:axis: Axis to apply masking on (1 -> time, 2 -> frequency)
:return_time_mask: Whether return the time mask of shape ``(1, T)``
"""
if axis not in [1, 2]:
raise ValueError("Only Frequency and Time masking are supported!")
if return_time_mask and axis == 1:
time_mask = torch.zeros(
1, features.size(0), dtype=torch.bool, device=features.device
)
else:
time_mask = None
features = features.unsqueeze(0)
features = features.reshape([-1] + list(features.size()[-2:]))
values = torch.randint(int(0), int(mask_size), (1, mask_times))
min_values = torch.rand(1, mask_times) * (features.size(axis) - values)
mask_starts = (min_values.long()).squeeze()
mask_ends = (min_values.long() + values.long()).squeeze()
if axis == 1:
if mask_times == 1:
features[:, mask_starts:mask_ends] = mask_value
if return_time_mask:
time_mask[:, mask_starts:mask_ends] = True
return features.squeeze(0), time_mask
for (mask_start, mask_end) in zip(mask_starts, mask_ends):
features[:, mask_start:mask_end] = mask_value
if return_time_mask:
time_mask[:, mask_start:mask_end] = True
else:
if mask_times == 1:
features[:, :, mask_starts:mask_ends] = mask_value
return features.squeeze(0), time_mask
for (mask_start, mask_end) in zip(mask_starts, mask_ends):
features[:, :, mask_start:mask_end] = mask_value
features = features.squeeze(0)
return features, time_mask
def time_warp(
features: torch.Tensor,
p: float = 0.9,
time_warp_factor: Optional[int] = 80,
supervision_segments: Optional[torch.Tensor] = None,
):
if time_warp_factor is None or time_warp_factor < 1:
return features
assert len(features.shape) == 3, (
"SpecAugment only supports batches of single-channel feature matrices."
)
features = features.clone()
if supervision_segments is None:
# No supervisions - apply spec augment to full feature matrices.
for sequence_idx in range(features.size(0)):
if random.random() > p:
# Randomly choose whether this transform is applied
continue
features[sequence_idx] = time_warp_impl(
features[sequence_idx], factor=time_warp_factor
)
else:
# Supervisions provided - we will apply time warping only on the supervised areas.
for sequence_idx, start_frame, num_frames in supervision_segments:
if random.random() > p:
# Randomly choose whether this transform is applied
continue
end_frame = start_frame + num_frames
features[sequence_idx, start_frame:end_frame] = time_warp_impl(
features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
)
return features

View File

@ -72,6 +72,7 @@ from attention_decoder import AttentionDecoderModel
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset import SpecAugment
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import AsrModel from model import AsrModel
@ -102,7 +103,6 @@ from icefall.utils import (
setup_logger, setup_logger,
str2bool, str2bool,
) )
from spec_augment import SpecAugment
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -460,22 +460,15 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--cr-loss-scale", "--cr-loss-scale",
type=float, type=float,
default=0.15, default=0.2,
help="Scale for consistency-regularization loss.", help="Scale for consistency-regularization loss.",
) )
parser.add_argument( parser.add_argument(
"--time-mask-ratio", "--time-mask-ratio",
type=float, type=float,
default=2.0, default=2.5,
help="When using cr-ctc, we increase the time-masking ratio.", help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
)
parser.add_argument(
"--cr-loss-masked-scale",
type=float,
default=1.0,
help="The value used to scale up the cr_loss at masked positions",
) )
parser.add_argument( parser.add_argument(
@ -950,7 +943,6 @@ def compute_loss(
spec_augment=spec_augment, spec_augment=spec_augment,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
time_warp_factor=params.spec_aug_time_warp_factor, time_warp_factor=params.spec_aug_time_warp_factor,
cr_loss_masked_scale=params.cr_loss_masked_scale,
) )
loss = 0.0 loss = 0.0

View File

@ -21,6 +21,7 @@ import argparse
import collections import collections
import logging import logging
import os import os
import random
import re import re
import subprocess import subprocess
from collections import defaultdict from collections import defaultdict
@ -38,6 +39,7 @@ import sentencepiece as spm
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
from pypinyin import lazy_pinyin, pinyin from pypinyin import lazy_pinyin, pinyin
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -2271,3 +2273,41 @@ def num_tokens(
if 0 in ans: if 0 in ans:
num_tokens -= 1 num_tokens -= 1
return num_tokens return num_tokens
# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
def time_warp(
features: torch.Tensor,
p: float = 0.9,
time_warp_factor: Optional[int] = 80,
supervision_segments: Optional[torch.Tensor] = None,
):
"""Apply time warping on a batch of features
"""
if time_warp_factor is None or time_warp_factor < 1:
return features
assert len(features.shape) == 3, (
"SpecAugment only supports batches of single-channel feature matrices."
)
features = features.clone()
if supervision_segments is None:
# No supervisions - apply spec augment to full feature matrices.
for sequence_idx in range(features.size(0)):
if random.random() > p:
# Randomly choose whether this transform is applied
continue
features[sequence_idx] = time_warp_impl(
features[sequence_idx], factor=time_warp_factor
)
else:
# Supervisions provided - we will apply time warping only on the supervised areas.
for sequence_idx, start_frame, num_frames in supervision_segments:
if random.random() > p:
# Randomly choose whether this transform is applied
continue
end_frame = start_frame + num_frames
features[sequence_idx, start_frame:end_frame] = time_warp_impl(
features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
)
return features