From 496fb963c8387fb9466ee898c4f1c7aa27115e1d Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 30 May 2022 10:49:08 +0800 Subject: [PATCH] icefall format --- .../ASR/pruned_transducer_stateless6/aug.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py index 202005d98..c60d328c5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/aug.py @@ -1,15 +1,10 @@ -import bisect import math import random -from typing import Dict, Optional, Sequence, Tuple, TypeVar, Union +from typing import Dict, Optional import numpy as np import torch -from lhotse import CutSet -from lhotse.augmentation import dereverb_wpe_torch -from lhotse.utils import Pathlike - class SpecAugment(torch.nn.Module): """ @@ -88,19 +83,26 @@ class SpecAugment(torch.nn.Module): :return: an augmented tensor of shape ``(B, T, F)``. """ assert len(features.shape) == 3, ( - "SpecAugment only supports batches of " "single-channel feature matrices." + "SpecAugment only supports batches of " + "single-channel feature matrices." ) features = features.clone() if supervision_segments is None: # No supervisions - apply spec augment to full feature matrices. for sequence_idx in range(features.size(0)): - features[sequence_idx] = self._forward_single(features[sequence_idx]) + features[sequence_idx] = self._forward_single( + features[sequence_idx] + ) else: # Supervisions provided - we will apply time warping only on the supervised areas. for sequence_idx, start_frame, num_frames in supervision_segments: end_frame = start_frame + num_frames - features[sequence_idx, start_frame:end_frame] = self._forward_single( - features[sequence_idx, start_frame:end_frame], warp=True, mask=False + features[ + sequence_idx, start_frame:end_frame + ] = self._forward_single( + features[sequence_idx, start_frame:end_frame], + warp=True, + mask=False, ) # ... and then time-mask the full feature matrices. Note that in this mode, # it might happen that masks are applied to different sequences/examples @@ -134,7 +136,9 @@ class SpecAugment(torch.nn.Module): axis=2, ) # Time masking - max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + max_tot_mask_frames = self.max_frames_mask_fraction * features.size( + 0 + ) num_frame_masks = min( self.num_frame_masks, math.ceil(max_tot_mask_frames / self.frames_mask_size), @@ -173,7 +177,9 @@ class SpecAugment(torch.nn.Module): self.features_mask_size = state_dict.get( "features_mask_size", self.features_mask_size ) - self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) + self.num_frame_masks = state_dict.get( + "num_frame_masks", self.num_frame_masks + ) self.frames_mask_size = state_dict.get( "frames_mask_size", self.frames_mask_size )