mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-14 11:34:20 +00:00
icefall format
This commit is contained in:
parent
03fa9957a6
commit
496fb963c8
@ -1,15 +1,10 @@
|
||||
import bisect
|
||||
import math
|
||||
import random
|
||||
from typing import Dict, Optional, Sequence, Tuple, TypeVar, Union
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lhotse import CutSet
|
||||
from lhotse.augmentation import dereverb_wpe_torch
|
||||
from lhotse.utils import Pathlike
|
||||
|
||||
|
||||
class SpecAugment(torch.nn.Module):
|
||||
"""
|
||||
@ -88,19 +83,26 @@ class SpecAugment(torch.nn.Module):
|
||||
:return: an augmented tensor of shape ``(B, T, F)``.
|
||||
"""
|
||||
assert len(features.shape) == 3, (
|
||||
"SpecAugment only supports batches of " "single-channel feature matrices."
|
||||
"SpecAugment only supports batches of "
|
||||
"single-channel feature matrices."
|
||||
)
|
||||
features = features.clone()
|
||||
if supervision_segments is None:
|
||||
# No supervisions - apply spec augment to full feature matrices.
|
||||
for sequence_idx in range(features.size(0)):
|
||||
features[sequence_idx] = self._forward_single(features[sequence_idx])
|
||||
features[sequence_idx] = self._forward_single(
|
||||
features[sequence_idx]
|
||||
)
|
||||
else:
|
||||
# Supervisions provided - we will apply time warping only on the supervised areas.
|
||||
for sequence_idx, start_frame, num_frames in supervision_segments:
|
||||
end_frame = start_frame + num_frames
|
||||
features[sequence_idx, start_frame:end_frame] = self._forward_single(
|
||||
features[sequence_idx, start_frame:end_frame], warp=True, mask=False
|
||||
features[
|
||||
sequence_idx, start_frame:end_frame
|
||||
] = self._forward_single(
|
||||
features[sequence_idx, start_frame:end_frame],
|
||||
warp=True,
|
||||
mask=False,
|
||||
)
|
||||
# ... and then time-mask the full feature matrices. Note that in this mode,
|
||||
# it might happen that masks are applied to different sequences/examples
|
||||
@ -134,7 +136,9 @@ class SpecAugment(torch.nn.Module):
|
||||
axis=2,
|
||||
)
|
||||
# Time masking
|
||||
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
|
||||
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(
|
||||
0
|
||||
)
|
||||
num_frame_masks = min(
|
||||
self.num_frame_masks,
|
||||
math.ceil(max_tot_mask_frames / self.frames_mask_size),
|
||||
@ -173,7 +177,9 @@ class SpecAugment(torch.nn.Module):
|
||||
self.features_mask_size = state_dict.get(
|
||||
"features_mask_size", self.features_mask_size
|
||||
)
|
||||
self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks)
|
||||
self.num_frame_masks = state_dict.get(
|
||||
"num_frame_masks", self.num_frame_masks
|
||||
)
|
||||
self.frames_mask_size = state_dict.get(
|
||||
"frames_mask_size", self.frames_mask_size
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user