mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-16 20:52:18 +00:00
refactor codes
This commit is contained in:
parent
a6eead6c98
commit
ae59e5d61e
@ -24,8 +24,8 @@ import torch.nn as nn
|
|||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
from scaling import ScaledLinear
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import add_sos, make_pad_mask
|
from icefall.utils import add_sos, make_pad_mask, time_warp
|
||||||
from spec_augment import SpecAugment, time_warp
|
from lhotse.dataset import SpecAugment
|
||||||
|
|
||||||
|
|
||||||
class AsrModel(nn.Module):
|
class AsrModel(nn.Module):
|
||||||
@ -188,8 +188,6 @@ class AsrModel(nn.Module):
|
|||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
targets: torch.Tensor,
|
targets: torch.Tensor,
|
||||||
target_lengths: torch.Tensor,
|
target_lengths: torch.Tensor,
|
||||||
time_mask: Optional[torch.Tensor] = None,
|
|
||||||
cr_loss_masked_scale: float = 1.0,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Compute CTC loss with consistency regularization loss.
|
"""Compute CTC loss with consistency regularization loss.
|
||||||
Args:
|
Args:
|
||||||
@ -200,10 +198,6 @@ class AsrModel(nn.Module):
|
|||||||
targets:
|
targets:
|
||||||
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
|
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
|
||||||
to be un-padded and concatenated within 1 dimension.
|
to be un-padded and concatenated within 1 dimension.
|
||||||
time_mask:
|
|
||||||
Downsampled time masks of shape (2 * N, T, 1).
|
|
||||||
cr_loss_masked_scale:
|
|
||||||
The loss scale used to scale up the cr_loss at masked positions.
|
|
||||||
"""
|
"""
|
||||||
# Compute CTC loss
|
# Compute CTC loss
|
||||||
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
|
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
|
||||||
@ -226,14 +220,6 @@ class AsrModel(nn.Module):
|
|||||||
reduction="none",
|
reduction="none",
|
||||||
log_target=True,
|
log_target=True,
|
||||||
) # (2 * N, T, C)
|
) # (2 * N, T, C)
|
||||||
if time_mask is not None:
|
|
||||||
assert time_mask.shape[:-1] == ctc_output.shape[:-1], (
|
|
||||||
time_mask.shape, ctc_output.shape
|
|
||||||
)
|
|
||||||
masked_scale = time_mask * (cr_loss_masked_scale - 1) + 1
|
|
||||||
# e.g., if cr_loss_masked_scale = 3, scales at masked positions are 3,
|
|
||||||
# scales at unmasked positions are 1
|
|
||||||
cr_loss = cr_loss * masked_scale # scaling up masked positions
|
|
||||||
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
|
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
|
||||||
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
|
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
|
||||||
|
|
||||||
@ -359,7 +345,6 @@ class AsrModel(nn.Module):
|
|||||||
spec_augment: Optional[SpecAugment] = None,
|
spec_augment: Optional[SpecAugment] = None,
|
||||||
supervision_segments: Optional[torch.Tensor] = None,
|
supervision_segments: Optional[torch.Tensor] = None,
|
||||||
time_warp_factor: Optional[int] = 80,
|
time_warp_factor: Optional[int] = 80,
|
||||||
cr_loss_masked_scale: float = 1.0,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -395,8 +380,6 @@ class AsrModel(nn.Module):
|
|||||||
Parameter for the time warping; larger values mean more warping.
|
Parameter for the time warping; larger values mean more warping.
|
||||||
Set to ``None``, or less than ``1``, to disable.
|
Set to ``None``, or less than ``1``, to disable.
|
||||||
Used only if use_cr_ctc is True.
|
Used only if use_cr_ctc is True.
|
||||||
cr_loss_masked_scale:
|
|
||||||
The loss scale used to scale up the cr_loss at masked positions.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer losses, CTC loss, AED loss,
|
Return the transducer losses, CTC loss, AED loss,
|
||||||
@ -429,12 +412,9 @@ class AsrModel(nn.Module):
|
|||||||
supervision_segments=supervision_segments,
|
supervision_segments=supervision_segments,
|
||||||
)
|
)
|
||||||
# Independently apply frequency masking and time masking to the two copies
|
# Independently apply frequency masking and time masking to the two copies
|
||||||
x, time_mask = spec_augment(x.repeat(2, 1, 1))
|
x = spec_augment(x.repeat(2, 1, 1))
|
||||||
# time_mask: 1 for masked, 0 for unmasked
|
|
||||||
time_mask = downsample_time_mask(time_mask, x.dtype)
|
|
||||||
else:
|
else:
|
||||||
x = x.repeat(2, 1, 1)
|
x = x.repeat(2, 1, 1)
|
||||||
time_mask = None
|
|
||||||
x_lens = x_lens.repeat(2)
|
x_lens = x_lens.repeat(2)
|
||||||
y = k2.ragged.cat([y, y], axis=0)
|
y = k2.ragged.cat([y, y], axis=0)
|
||||||
|
|
||||||
@ -479,8 +459,6 @@ class AsrModel(nn.Module):
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
target_lengths=y_lens,
|
target_lengths=y_lens,
|
||||||
time_mask=time_mask,
|
|
||||||
cr_loss_masked_scale=cr_loss_masked_scale,
|
|
||||||
)
|
)
|
||||||
ctc_loss = ctc_loss * 0.5
|
ctc_loss = ctc_loss * 0.5
|
||||||
cr_loss = cr_loss * 0.5
|
cr_loss = cr_loss * 0.5
|
||||||
@ -501,31 +479,3 @@ class AsrModel(nn.Module):
|
|||||||
attention_decoder_loss = torch.empty(0)
|
attention_decoder_loss = torch.empty(0)
|
||||||
|
|
||||||
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
|
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
|
||||||
|
|
||||||
|
|
||||||
def downsample_time_mask(time_mask: torch.Tensor, dtype: torch.dtype):
|
|
||||||
"""Downsample the time masks as in Zipformer.
|
|
||||||
Args:
|
|
||||||
time_mask: shape of (N, T)
|
|
||||||
Returns:
|
|
||||||
The downsampled time masks of shape (N, T', 1),
|
|
||||||
where T' = ((T - 7) // 2 + 1) // 2
|
|
||||||
"""
|
|
||||||
# Downsample the time masks as in Zipformer
|
|
||||||
time_mask = time_mask.to(dtype).unsqueeze(dim=1)
|
|
||||||
# as in conv-embed
|
|
||||||
time_mask = nn.functional.max_pool1d(
|
|
||||||
time_mask, kernel_size=3, stride=1, padding=0
|
|
||||||
) # T - 2
|
|
||||||
time_mask = nn.functional.max_pool1d(
|
|
||||||
time_mask, kernel_size=3, stride=2, padding=0
|
|
||||||
) # (T - 3) // 2
|
|
||||||
time_mask = nn.functional.max_pool1d(
|
|
||||||
time_mask, kernel_size=3, stride=1, padding=0
|
|
||||||
) # (T - 7) // 2
|
|
||||||
# as in output-downsampling
|
|
||||||
time_mask = nn.functional.max_pool1d(
|
|
||||||
time_mask, kernel_size=2, stride=2, padding=0, ceil_mode=True
|
|
||||||
)
|
|
||||||
time_mask = time_mask.transpose(1, 2) # (N * 2, T', 1)
|
|
||||||
return time_mask
|
|
||||||
|
@ -1,313 +0,0 @@
|
|||||||
# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
|
|
||||||
# with minor modification for cr-ctc training.
|
|
||||||
|
|
||||||
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
|
|
||||||
|
|
||||||
|
|
||||||
class SpecAugment(torch.nn.Module):
|
|
||||||
"""SpecAugment from lhotse with minor modification, returning time masks.
|
|
||||||
|
|
||||||
SpecAugment performs three augmentations:
|
|
||||||
- time warping of the feature matrix
|
|
||||||
- masking of ranges of features (frequency bands)
|
|
||||||
- masking of ranges of frames (time)
|
|
||||||
|
|
||||||
The current implementation works with batches, but processes each example separately
|
|
||||||
in a loop rather than simultaneously to achieve different augmentation parameters for
|
|
||||||
each example.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
time_warp_factor: Optional[int] = 80,
|
|
||||||
num_feature_masks: int = 2,
|
|
||||||
features_mask_size: int = 27,
|
|
||||||
num_frame_masks: int = 10,
|
|
||||||
frames_mask_size: int = 100,
|
|
||||||
max_frames_mask_fraction: float = 0.15,
|
|
||||||
p=0.9,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
SpecAugment's constructor.
|
|
||||||
|
|
||||||
:param time_warp_factor: parameter for the time warping; larger values mean more warping.
|
|
||||||
Set to ``None``, or less than ``1``, to disable.
|
|
||||||
:param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable.
|
|
||||||
:param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
|
|
||||||
This is the ``F`` parameter from the SpecAugment paper.
|
|
||||||
:param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable.
|
|
||||||
:param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
|
|
||||||
This is the ``T`` parameter from the SpecAugment paper.
|
|
||||||
:param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
|
|
||||||
of the utterance (or supervision segment).
|
|
||||||
This is the parameter denoted by ``p`` in the SpecAugment paper.
|
|
||||||
:param p: the probability of applying this transform.
|
|
||||||
It is different from ``p`` in the SpecAugment paper!
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
assert 0 <= p <= 1
|
|
||||||
assert num_feature_masks >= 0
|
|
||||||
assert num_frame_masks >= 0
|
|
||||||
assert features_mask_size > 0
|
|
||||||
assert frames_mask_size > 0
|
|
||||||
self.time_warp_factor = time_warp_factor
|
|
||||||
self.num_feature_masks = num_feature_masks
|
|
||||||
self.features_mask_size = features_mask_size
|
|
||||||
self.num_frame_masks = num_frame_masks
|
|
||||||
self.frames_mask_size = frames_mask_size
|
|
||||||
self.max_frames_mask_fraction = max_frames_mask_fraction
|
|
||||||
self.p = p
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
features: torch.Tensor,
|
|
||||||
supervision_segments: Optional[torch.IntTensor] = None,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Computes SpecAugment for a batch of feature matrices.
|
|
||||||
|
|
||||||
Since the batch will usually already be padded, the user can optionally
|
|
||||||
provide a ``supervision_segments`` tensor that will be used to apply SpecAugment
|
|
||||||
only to selected areas of the input. The format of this input is described below.
|
|
||||||
|
|
||||||
:param features: a batch of feature matrices with shape ``(B, T, F)``.
|
|
||||||
:param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of
|
|
||||||
supervision segments that exist in ``features`` -- there may be either
|
|
||||||
less or more than the batch size.
|
|
||||||
The second dimension encoder three kinds of information:
|
|
||||||
the sequence index of the corresponding feature matrix in `features`,
|
|
||||||
the start frame index, and the number of frames for each segment.
|
|
||||||
:return:
|
|
||||||
- an augmented tensor of shape ``(B, T, F)``.
|
|
||||||
- the corresponding time masks of shape ``(B, T)``.
|
|
||||||
"""
|
|
||||||
assert len(features.shape) == 3, (
|
|
||||||
"SpecAugment only supports batches of " "single-channel feature matrices."
|
|
||||||
)
|
|
||||||
features = features.clone()
|
|
||||||
|
|
||||||
time_masks = []
|
|
||||||
|
|
||||||
if supervision_segments is None:
|
|
||||||
# No supervisions - apply spec augment to full feature matrices.
|
|
||||||
for sequence_idx in range(features.size(0)):
|
|
||||||
masked_feature, time_mask = self._forward_single(features[sequence_idx])
|
|
||||||
features[sequence_idx] = masked_feature
|
|
||||||
time_masks.append(time_mask)
|
|
||||||
else:
|
|
||||||
# Supervisions provided - we will apply time warping only on the supervised areas.
|
|
||||||
for sequence_idx, start_frame, num_frames in supervision_segments:
|
|
||||||
end_frame = start_frame + num_frames
|
|
||||||
warped_feature, _ = self._forward_single(
|
|
||||||
features[sequence_idx, start_frame:end_frame], warp=True, mask=False
|
|
||||||
)
|
|
||||||
features[sequence_idx, start_frame:end_frame] = warped_feature
|
|
||||||
# ... and then time-mask the full feature matrices. Note that in this mode,
|
|
||||||
# it might happen that masks are applied to different sequences/examples
|
|
||||||
# than the time warping.
|
|
||||||
for sequence_idx in range(features.size(0)):
|
|
||||||
masked_feature, time_mask = self._forward_single(
|
|
||||||
features[sequence_idx], warp=False, mask=True
|
|
||||||
)
|
|
||||||
features[sequence_idx] = masked_feature
|
|
||||||
time_masks.append(time_mask)
|
|
||||||
|
|
||||||
time_masks = torch.cat(time_masks, dim=0)
|
|
||||||
assert time_masks.shape == features.shape[:-1], (time_masks.shape == features.shape[:-1])
|
|
||||||
return features, time_masks
|
|
||||||
|
|
||||||
def _forward_single(
|
|
||||||
self, features: torch.Tensor, warp: bool = True, mask: bool = True
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Apply SpecAugment to a single feature matrix of shape (T, F).
|
|
||||||
"""
|
|
||||||
if random.random() > self.p:
|
|
||||||
# Randomly choose whether this transform is applied
|
|
||||||
time_mask = torch.zeros(
|
|
||||||
1, features.size(0), dtype=torch.bool, device=features.device
|
|
||||||
)
|
|
||||||
return features, time_mask
|
|
||||||
|
|
||||||
time_mask = None
|
|
||||||
if warp:
|
|
||||||
if self.time_warp_factor is not None and self.time_warp_factor >= 1:
|
|
||||||
features = time_warp_impl(features, factor=self.time_warp_factor)
|
|
||||||
|
|
||||||
if mask:
|
|
||||||
mean = features.mean()
|
|
||||||
# Frequency masking
|
|
||||||
features, _ = mask_along_axis_optimized(
|
|
||||||
features,
|
|
||||||
mask_size=self.features_mask_size,
|
|
||||||
mask_times=self.num_feature_masks,
|
|
||||||
mask_value=mean,
|
|
||||||
axis=2,
|
|
||||||
)
|
|
||||||
# Time masking
|
|
||||||
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
|
|
||||||
num_frame_masks = min(
|
|
||||||
self.num_frame_masks,
|
|
||||||
math.ceil(max_tot_mask_frames / self.frames_mask_size),
|
|
||||||
)
|
|
||||||
max_mask_frames = min(
|
|
||||||
self.frames_mask_size, max_tot_mask_frames // num_frame_masks
|
|
||||||
)
|
|
||||||
features, time_mask = mask_along_axis_optimized(
|
|
||||||
features,
|
|
||||||
mask_size=max_mask_frames,
|
|
||||||
mask_times=num_frame_masks,
|
|
||||||
mask_value=mean,
|
|
||||||
axis=1,
|
|
||||||
return_time_mask=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return features, time_mask
|
|
||||||
|
|
||||||
def state_dict(self, **kwargs) -> Dict[str, Any]:
|
|
||||||
return dict(
|
|
||||||
time_warp_factor=self.time_warp_factor,
|
|
||||||
num_feature_masks=self.num_feature_masks,
|
|
||||||
features_mask_size=self.features_mask_size,
|
|
||||||
num_frame_masks=self.num_frame_masks,
|
|
||||||
frames_mask_size=self.frames_mask_size,
|
|
||||||
max_frames_mask_fraction=self.max_frames_mask_fraction,
|
|
||||||
p=self.p,
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict[str, Any]):
|
|
||||||
self.time_warp_factor = state_dict.get(
|
|
||||||
"time_warp_factor", self.time_warp_factor
|
|
||||||
)
|
|
||||||
self.num_feature_masks = state_dict.get(
|
|
||||||
"num_feature_masks", self.num_feature_masks
|
|
||||||
)
|
|
||||||
self.features_mask_size = state_dict.get(
|
|
||||||
"features_mask_size", self.features_mask_size
|
|
||||||
)
|
|
||||||
self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks)
|
|
||||||
self.frames_mask_size = state_dict.get(
|
|
||||||
"frames_mask_size", self.frames_mask_size
|
|
||||||
)
|
|
||||||
self.max_frames_mask_fraction = state_dict.get(
|
|
||||||
"max_frames_mask_fraction", self.max_frames_mask_fraction
|
|
||||||
)
|
|
||||||
self.p = state_dict.get("p", self.p)
|
|
||||||
|
|
||||||
|
|
||||||
def mask_along_axis_optimized(
|
|
||||||
features: torch.Tensor,
|
|
||||||
mask_size: int,
|
|
||||||
mask_times: int,
|
|
||||||
mask_value: float,
|
|
||||||
axis: int,
|
|
||||||
return_time_mask: bool = False,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Apply Frequency and Time masking along axis.
|
|
||||||
Frequency and Time masking as described in the SpecAugment paper.
|
|
||||||
|
|
||||||
:param features: input tensor of shape ``(T, F)``
|
|
||||||
:mask_size: the width size for masking.
|
|
||||||
:mask_times: the number of masking regions.
|
|
||||||
:mask_value: Value to assign to the masked regions.
|
|
||||||
:axis: Axis to apply masking on (1 -> time, 2 -> frequency)
|
|
||||||
:return_time_mask: Whether return the time mask of shape ``(1, T)``
|
|
||||||
"""
|
|
||||||
if axis not in [1, 2]:
|
|
||||||
raise ValueError("Only Frequency and Time masking are supported!")
|
|
||||||
|
|
||||||
if return_time_mask and axis == 1:
|
|
||||||
time_mask = torch.zeros(
|
|
||||||
1, features.size(0), dtype=torch.bool, device=features.device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
time_mask = None
|
|
||||||
|
|
||||||
features = features.unsqueeze(0)
|
|
||||||
features = features.reshape([-1] + list(features.size()[-2:]))
|
|
||||||
|
|
||||||
values = torch.randint(int(0), int(mask_size), (1, mask_times))
|
|
||||||
min_values = torch.rand(1, mask_times) * (features.size(axis) - values)
|
|
||||||
mask_starts = (min_values.long()).squeeze()
|
|
||||||
mask_ends = (min_values.long() + values.long()).squeeze()
|
|
||||||
|
|
||||||
if axis == 1:
|
|
||||||
if mask_times == 1:
|
|
||||||
features[:, mask_starts:mask_ends] = mask_value
|
|
||||||
if return_time_mask:
|
|
||||||
time_mask[:, mask_starts:mask_ends] = True
|
|
||||||
return features.squeeze(0), time_mask
|
|
||||||
for (mask_start, mask_end) in zip(mask_starts, mask_ends):
|
|
||||||
features[:, mask_start:mask_end] = mask_value
|
|
||||||
if return_time_mask:
|
|
||||||
time_mask[:, mask_start:mask_end] = True
|
|
||||||
else:
|
|
||||||
if mask_times == 1:
|
|
||||||
features[:, :, mask_starts:mask_ends] = mask_value
|
|
||||||
return features.squeeze(0), time_mask
|
|
||||||
for (mask_start, mask_end) in zip(mask_starts, mask_ends):
|
|
||||||
features[:, :, mask_start:mask_end] = mask_value
|
|
||||||
|
|
||||||
features = features.squeeze(0)
|
|
||||||
return features, time_mask
|
|
||||||
|
|
||||||
|
|
||||||
def time_warp(
|
|
||||||
features: torch.Tensor,
|
|
||||||
p: float = 0.9,
|
|
||||||
time_warp_factor: Optional[int] = 80,
|
|
||||||
supervision_segments: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
if time_warp_factor is None or time_warp_factor < 1:
|
|
||||||
return features
|
|
||||||
assert len(features.shape) == 3, (
|
|
||||||
"SpecAugment only supports batches of single-channel feature matrices."
|
|
||||||
)
|
|
||||||
features = features.clone()
|
|
||||||
if supervision_segments is None:
|
|
||||||
# No supervisions - apply spec augment to full feature matrices.
|
|
||||||
for sequence_idx in range(features.size(0)):
|
|
||||||
if random.random() > p:
|
|
||||||
# Randomly choose whether this transform is applied
|
|
||||||
continue
|
|
||||||
features[sequence_idx] = time_warp_impl(
|
|
||||||
features[sequence_idx], factor=time_warp_factor
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Supervisions provided - we will apply time warping only on the supervised areas.
|
|
||||||
for sequence_idx, start_frame, num_frames in supervision_segments:
|
|
||||||
if random.random() > p:
|
|
||||||
# Randomly choose whether this transform is applied
|
|
||||||
continue
|
|
||||||
end_frame = start_frame + num_frames
|
|
||||||
features[sequence_idx, start_frame:end_frame] = time_warp_impl(
|
|
||||||
features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
return features
|
|
@ -72,6 +72,7 @@ from attention_decoder import AttentionDecoderModel
|
|||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
|
from lhotse.dataset import SpecAugment
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import AsrModel
|
from model import AsrModel
|
||||||
@ -102,7 +103,6 @@ from icefall.utils import (
|
|||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
from spec_augment import SpecAugment
|
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
@ -460,22 +460,15 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cr-loss-scale",
|
"--cr-loss-scale",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.15,
|
default=0.2,
|
||||||
help="Scale for consistency-regularization loss.",
|
help="Scale for consistency-regularization loss.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--time-mask-ratio",
|
"--time-mask-ratio",
|
||||||
type=float,
|
type=float,
|
||||||
default=2.0,
|
default=2.5,
|
||||||
help="When using cr-ctc, we increase the time-masking ratio.",
|
help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--cr-loss-masked-scale",
|
|
||||||
type=float,
|
|
||||||
default=1.0,
|
|
||||||
help="The value used to scale up the cr_loss at masked positions",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -950,7 +943,6 @@ def compute_loss(
|
|||||||
spec_augment=spec_augment,
|
spec_augment=spec_augment,
|
||||||
supervision_segments=supervision_segments,
|
supervision_segments=supervision_segments,
|
||||||
time_warp_factor=params.spec_aug_time_warp_factor,
|
time_warp_factor=params.spec_aug_time_warp_factor,
|
||||||
cr_loss_masked_scale=params.cr_loss_masked_scale,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
@ -21,6 +21,7 @@ import argparse
|
|||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@ -38,6 +39,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
|
||||||
from pypinyin import lazy_pinyin, pinyin
|
from pypinyin import lazy_pinyin, pinyin
|
||||||
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
|
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
@ -2271,3 +2273,41 @@ def num_tokens(
|
|||||||
if 0 in ans:
|
if 0 in ans:
|
||||||
num_tokens -= 1
|
num_tokens -= 1
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
|
|
||||||
|
# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
|
||||||
|
def time_warp(
|
||||||
|
features: torch.Tensor,
|
||||||
|
p: float = 0.9,
|
||||||
|
time_warp_factor: Optional[int] = 80,
|
||||||
|
supervision_segments: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""Apply time warping on a batch of features
|
||||||
|
"""
|
||||||
|
if time_warp_factor is None or time_warp_factor < 1:
|
||||||
|
return features
|
||||||
|
assert len(features.shape) == 3, (
|
||||||
|
"SpecAugment only supports batches of single-channel feature matrices."
|
||||||
|
)
|
||||||
|
features = features.clone()
|
||||||
|
if supervision_segments is None:
|
||||||
|
# No supervisions - apply spec augment to full feature matrices.
|
||||||
|
for sequence_idx in range(features.size(0)):
|
||||||
|
if random.random() > p:
|
||||||
|
# Randomly choose whether this transform is applied
|
||||||
|
continue
|
||||||
|
features[sequence_idx] = time_warp_impl(
|
||||||
|
features[sequence_idx], factor=time_warp_factor
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Supervisions provided - we will apply time warping only on the supervised areas.
|
||||||
|
for sequence_idx, start_frame, num_frames in supervision_segments:
|
||||||
|
if random.random() > p:
|
||||||
|
# Randomly choose whether this transform is applied
|
||||||
|
continue
|
||||||
|
end_frame = start_frame + num_frames
|
||||||
|
features[sequence_idx, start_frame:end_frame] = time_warp_impl(
|
||||||
|
features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
|
||||||
|
)
|
||||||
|
|
||||||
|
return features
|
||||||
|
Loading…
x
Reference in New Issue
Block a user