mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
Apply random frame shift along the time axis.
This commit is contained in:
parent
35ecd7e562
commit
8653b6a68a
@ -19,7 +19,9 @@ import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
@ -179,7 +181,27 @@ class LibriSpeechAsrDataModule:
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
extra_input_transforms: Optional[
|
||||
List[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
|
||||
],
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
The cutset for training.
|
||||
extra_input_transforms:
|
||||
The extra input transforms that will be applied after all input
|
||||
transforms, e.g., after SpecAugment if there is any.
|
||||
Each input transform accepts two input arguments:
|
||||
- A 3-D torch.Tensor of shape (N, T, C)
|
||||
- A 2-D torch.Tensor of shape (num_seqs, 3), where the
|
||||
first column is `sequence_idx`, the second column is
|
||||
`start_frame`, and the third column is `num_frames`.
|
||||
and returns a 3-D torch.Tensor of shape (N, T, C).
|
||||
"""
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "cuts_musan.json.gz"
|
||||
@ -228,6 +250,10 @@ class LibriSpeechAsrDataModule:
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
if extra_input_transforms is not None:
|
||||
input_transforms += extra_input_transforms
|
||||
logging.info(f"Input transforms: {input_transforms}")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
|
84
egs/librispeech/ASR/transducer_stateless/frame_shift.py
Normal file
84
egs/librispeech/ASR/transducer_stateless/frame_shift.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from lhotse.utils import LOG_EPSILON
|
||||
|
||||
|
||||
def apply_frame_shift(
|
||||
features: torch.Tensor,
|
||||
supervision_segments: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Apply random frame shift along the time axis.
|
||||
|
||||
For instance, for the input frame `[a, b, c, d]`,
|
||||
|
||||
- If frame shift is 0, the resulting output is `[a, b, c, d]`
|
||||
- If frame shift is -1, the resulting output is `[b, c, d, a]`
|
||||
- If frame shift is 1, the resulting output is `[d, a, b, c]`
|
||||
- If frame shift is 2, the resulting output is `[c, d, a, b]`
|
||||
|
||||
Args:
|
||||
features:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
supervision_segments:
|
||||
A 2-D tensor of shape (num_seqs, 3). The first column is
|
||||
`sequence_idx`, the second column is `start_frame`, and
|
||||
the third column is `num_frames`.
|
||||
Returns:
|
||||
Return a 3-D tensor of shape (N, T, C).
|
||||
"""
|
||||
# We assume the subsampling_factor is 4. If you change the
|
||||
# subsampling_factor, you should also change the following
|
||||
# list accordingly
|
||||
#
|
||||
# The value in frame_shifts is selected in such a way that
|
||||
# "value % subsampling_factor" is not duplicated in frame_shifts.
|
||||
frame_shifts = [-1, 0, 1, 2]
|
||||
|
||||
N = features.size(0)
|
||||
|
||||
# We don't support cut concatenation here
|
||||
assert torch.all(
|
||||
torch.eq(supervision_segments[:, 0], torch.arange(N))
|
||||
), supervision_segments
|
||||
|
||||
ans = []
|
||||
for i in range(N):
|
||||
start = supervision_segments[i, 1]
|
||||
end = start + supervision_segments[i, 2]
|
||||
|
||||
feat = features[i, start:end, :]
|
||||
|
||||
r = torch.randint(low=0, high=len(frame_shifts), size=(1,)).item()
|
||||
frame_shift = frame_shifts[r]
|
||||
|
||||
# You can enable the following debug statement
|
||||
# and run ./transducer_stateless/test_frame_shift.py to
|
||||
# view the debug output.
|
||||
# print("frame_shift", frame_shift)
|
||||
|
||||
feat = torch.roll(feat, shifts=frame_shift, dims=0)
|
||||
ans.append(feat)
|
||||
|
||||
ans = torch.nn.utils.rnn.pad_sequence(
|
||||
ans,
|
||||
batch_first=True,
|
||||
padding_value=LOG_EPSILON,
|
||||
)
|
||||
assert features.shape == ans.shape
|
||||
|
||||
return ans
|
70
egs/librispeech/ASR/transducer_stateless/test_frame_shift.py
Executable file
70
egs/librispeech/ASR/transducer_stateless/test_frame_shift.py
Executable file
@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./transducer_stateless/test_frame_shift.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from frame_shift import apply_frame_shift
|
||||
|
||||
|
||||
def test_apply_frame_shift():
|
||||
features = torch.tensor(
|
||||
[
|
||||
[
|
||||
[1, 2, 5],
|
||||
[2, 6, 9],
|
||||
[3, 0, 2],
|
||||
[4, 11, 13],
|
||||
[0, 0, 0],
|
||||
[0, 0, 0],
|
||||
],
|
||||
[
|
||||
[1, 3, 9],
|
||||
[2, 5, 8],
|
||||
[3, 3, 6],
|
||||
[4, 0, 3],
|
||||
[5, 1, 2],
|
||||
[6, 6, 6],
|
||||
],
|
||||
]
|
||||
)
|
||||
supervision_segments = torch.tensor(
|
||||
[
|
||||
[0, 0, 4],
|
||||
[1, 0, 6],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
shifted_features = apply_frame_shift(features, supervision_segments)
|
||||
|
||||
# You can enable the debug statement in frame_shift.py
|
||||
# and check the resulting shifted_features. I've verified
|
||||
# manually that it is correct.
|
||||
print(shifted_features)
|
||||
|
||||
|
||||
def main():
|
||||
test_apply_frame_shift()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -46,6 +46,7 @@ import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from frame_shift import apply_frame_shift
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
@ -138,6 +139,13 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--apply-frame-shift",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If enabled, apply random frame shift along the time axis",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -620,7 +628,17 @@ def run(rank, world_size, args):
|
||||
logging.info(f"After removing short and long utterances: {num_left}")
|
||||
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
||||
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
if params.apply_frame_shift:
|
||||
logging.info("Enable random frame shift")
|
||||
extra_input_transforms = [apply_frame_shift]
|
||||
else:
|
||||
logging.info("Disable random frame shift")
|
||||
extra_input_transforms = None
|
||||
|
||||
train_dl = librispeech.train_dataloaders(
|
||||
train_cuts,
|
||||
extra_input_transforms=extra_input_transforms,
|
||||
)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
|
Loading…
x
Reference in New Issue
Block a user