icefall format

This commit is contained in:
Guo Liyong 2022-05-30 10:49:08 +08:00
parent 03fa9957a6
commit 496fb963c8

View File

@ -1,15 +1,10 @@
import bisect
import math import math
import random import random
from typing import Dict, Optional, Sequence, Tuple, TypeVar, Union from typing import Dict, Optional
import numpy as np import numpy as np
import torch import torch
from lhotse import CutSet
from lhotse.augmentation import dereverb_wpe_torch
from lhotse.utils import Pathlike
class SpecAugment(torch.nn.Module): class SpecAugment(torch.nn.Module):
""" """
@ -88,19 +83,26 @@ class SpecAugment(torch.nn.Module):
:return: an augmented tensor of shape ``(B, T, F)``. :return: an augmented tensor of shape ``(B, T, F)``.
""" """
assert len(features.shape) == 3, ( assert len(features.shape) == 3, (
"SpecAugment only supports batches of " "single-channel feature matrices." "SpecAugment only supports batches of "
"single-channel feature matrices."
) )
features = features.clone() features = features.clone()
if supervision_segments is None: if supervision_segments is None:
# No supervisions - apply spec augment to full feature matrices. # No supervisions - apply spec augment to full feature matrices.
for sequence_idx in range(features.size(0)): for sequence_idx in range(features.size(0)):
features[sequence_idx] = self._forward_single(features[sequence_idx]) features[sequence_idx] = self._forward_single(
features[sequence_idx]
)
else: else:
# Supervisions provided - we will apply time warping only on the supervised areas. # Supervisions provided - we will apply time warping only on the supervised areas.
for sequence_idx, start_frame, num_frames in supervision_segments: for sequence_idx, start_frame, num_frames in supervision_segments:
end_frame = start_frame + num_frames end_frame = start_frame + num_frames
features[sequence_idx, start_frame:end_frame] = self._forward_single( features[
features[sequence_idx, start_frame:end_frame], warp=True, mask=False sequence_idx, start_frame:end_frame
] = self._forward_single(
features[sequence_idx, start_frame:end_frame],
warp=True,
mask=False,
) )
# ... and then time-mask the full feature matrices. Note that in this mode, # ... and then time-mask the full feature matrices. Note that in this mode,
# it might happen that masks are applied to different sequences/examples # it might happen that masks are applied to different sequences/examples
@ -134,7 +136,9 @@ class SpecAugment(torch.nn.Module):
axis=2, axis=2,
) )
# Time masking # Time masking
max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) max_tot_mask_frames = self.max_frames_mask_fraction * features.size(
0
)
num_frame_masks = min( num_frame_masks = min(
self.num_frame_masks, self.num_frame_masks,
math.ceil(max_tot_mask_frames / self.frames_mask_size), math.ceil(max_tot_mask_frames / self.frames_mask_size),
@ -173,7 +177,9 @@ class SpecAugment(torch.nn.Module):
self.features_mask_size = state_dict.get( self.features_mask_size = state_dict.get(
"features_mask_size", self.features_mask_size "features_mask_size", self.features_mask_size
) )
self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) self.num_frame_masks = state_dict.get(
"num_frame_masks", self.num_frame_masks
)
self.frames_mask_size = state_dict.get( self.frames_mask_size = state_dict.get(
"frames_mask_size", self.frames_mask_size "frames_mask_size", self.frames_mask_size
) )