mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-14 19:44:21 +00:00
icefall format
This commit is contained in:
parent
03fa9957a6
commit
496fb963c8
@ -1,15 +1,10 @@
|
|||||||
import bisect
|
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Dict, Optional, Sequence, Tuple, TypeVar, Union
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lhotse import CutSet
|
|
||||||
from lhotse.augmentation import dereverb_wpe_torch
|
|
||||||
from lhotse.utils import Pathlike
|
|
||||||
|
|
||||||
|
|
||||||
class SpecAugment(torch.nn.Module):
|
class SpecAugment(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
@ -88,19 +83,26 @@ class SpecAugment(torch.nn.Module):
|
|||||||
:return: an augmented tensor of shape ``(B, T, F)``.
|
:return: an augmented tensor of shape ``(B, T, F)``.
|
||||||
"""
|
"""
|
||||||
assert len(features.shape) == 3, (
|
assert len(features.shape) == 3, (
|
||||||
"SpecAugment only supports batches of " "single-channel feature matrices."
|
"SpecAugment only supports batches of "
|
||||||
|
"single-channel feature matrices."
|
||||||
)
|
)
|
||||||
features = features.clone()
|
features = features.clone()
|
||||||
if supervision_segments is None:
|
if supervision_segments is None:
|
||||||
# No supervisions - apply spec augment to full feature matrices.
|
# No supervisions - apply spec augment to full feature matrices.
|
||||||
for sequence_idx in range(features.size(0)):
|
for sequence_idx in range(features.size(0)):
|
||||||
features[sequence_idx] = self._forward_single(features[sequence_idx])
|
features[sequence_idx] = self._forward_single(
|
||||||
|
features[sequence_idx]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Supervisions provided - we will apply time warping only on the supervised areas.
|
# Supervisions provided - we will apply time warping only on the supervised areas.
|
||||||
for sequence_idx, start_frame, num_frames in supervision_segments:
|
for sequence_idx, start_frame, num_frames in supervision_segments:
|
||||||
end_frame = start_frame + num_frames
|
end_frame = start_frame + num_frames
|
||||||
features[sequence_idx, start_frame:end_frame] = self._forward_single(
|
features[
|
||||||
features[sequence_idx, start_frame:end_frame], warp=True, mask=False
|
sequence_idx, start_frame:end_frame
|
||||||
|
] = self._forward_single(
|
||||||
|
features[sequence_idx, start_frame:end_frame],
|
||||||
|
warp=True,
|
||||||
|
mask=False,
|
||||||
)
|
)
|
||||||
# ... and then time-mask the full feature matrices. Note that in this mode,
|
# ... and then time-mask the full feature matrices. Note that in this mode,
|
||||||
# it might happen that masks are applied to different sequences/examples
|
# it might happen that masks are applied to different sequences/examples
|
||||||
@ -134,7 +136,9 @@ class SpecAugment(torch.nn.Module):
|
|||||||
axis=2,
|
axis=2,
|
||||||
)
|
)
|
||||||
# Time masking
|
# Time masking
|
||||||
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
|
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(
|
||||||
|
0
|
||||||
|
)
|
||||||
num_frame_masks = min(
|
num_frame_masks = min(
|
||||||
self.num_frame_masks,
|
self.num_frame_masks,
|
||||||
math.ceil(max_tot_mask_frames / self.frames_mask_size),
|
math.ceil(max_tot_mask_frames / self.frames_mask_size),
|
||||||
@ -173,7 +177,9 @@ class SpecAugment(torch.nn.Module):
|
|||||||
self.features_mask_size = state_dict.get(
|
self.features_mask_size = state_dict.get(
|
||||||
"features_mask_size", self.features_mask_size
|
"features_mask_size", self.features_mask_size
|
||||||
)
|
)
|
||||||
self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks)
|
self.num_frame_masks = state_dict.get(
|
||||||
|
"num_frame_masks", self.num_frame_masks
|
||||||
|
)
|
||||||
self.frames_mask_size = state_dict.get(
|
self.frames_mask_size = state_dict.get(
|
||||||
"frames_mask_size", self.frames_mask_size
|
"frames_mask_size", self.frames_mask_size
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user