icefall format

This commit is contained in:
Guo Liyong 2022-05-30 10:49:08 +08:00
parent 03fa9957a6
commit 496fb963c8

View File

@ -1,15 +1,10 @@
import bisect
import math
import random
from typing import Dict, Optional, Sequence, Tuple, TypeVar, Union
from typing import Dict, Optional
import numpy as np
import torch
from lhotse import CutSet
from lhotse.augmentation import dereverb_wpe_torch
from lhotse.utils import Pathlike
class SpecAugment(torch.nn.Module):
"""
@ -88,19 +83,26 @@ class SpecAugment(torch.nn.Module):
:return: an augmented tensor of shape ``(B, T, F)``.
"""
assert len(features.shape) == 3, (
"SpecAugment only supports batches of " "single-channel feature matrices."
"SpecAugment only supports batches of "
"single-channel feature matrices."
)
features = features.clone()
if supervision_segments is None:
# No supervisions - apply spec augment to full feature matrices.
for sequence_idx in range(features.size(0)):
features[sequence_idx] = self._forward_single(features[sequence_idx])
features[sequence_idx] = self._forward_single(
features[sequence_idx]
)
else:
# Supervisions provided - we will apply time warping only on the supervised areas.
for sequence_idx, start_frame, num_frames in supervision_segments:
end_frame = start_frame + num_frames
features[sequence_idx, start_frame:end_frame] = self._forward_single(
features[sequence_idx, start_frame:end_frame], warp=True, mask=False
features[
sequence_idx, start_frame:end_frame
] = self._forward_single(
features[sequence_idx, start_frame:end_frame],
warp=True,
mask=False,
)
# ... and then time-mask the full feature matrices. Note that in this mode,
# it might happen that masks are applied to different sequences/examples
@ -134,7 +136,9 @@ class SpecAugment(torch.nn.Module):
axis=2,
)
# Time masking
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(
0
)
num_frame_masks = min(
self.num_frame_masks,
math.ceil(max_tot_mask_frames / self.frames_mask_size),
@ -173,7 +177,9 @@ class SpecAugment(torch.nn.Module):
self.features_mask_size = state_dict.get(
"features_mask_size", self.features_mask_size
)
self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks)
self.num_frame_masks = state_dict.get(
"num_frame_masks", self.num_frame_masks
)
self.frames_mask_size = state_dict.get(
"frames_mask_size", self.frames_mask_size
)