mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Remove long utterances to avoid OOM when a large max_duration is used.
This commit is contained in:
parent
cd5ed7db20
commit
89a08b64ce
215
egs/librispeech/ASR/local/display_manifest_statistics.py
Executable file
215
egs/librispeech/ASR/local/display_manifest_statistics.py
Executable file
@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file displays duration statistics of utterances in a manifest.
|
||||
You can use the displayed value to choose minimum/maximum duration
|
||||
to remove short and long utterances during training.
|
||||
|
||||
See the function `remove_short_and_long_utt()` in transducer/train.py
|
||||
for usage.
|
||||
"""
|
||||
|
||||
|
||||
from lhotse import load_manifest
|
||||
|
||||
|
||||
def main():
    """Print duration statistics for each LibriSpeech cut manifest.

    The original code re-assigned ``path`` seven times, so every
    assignment except the last was dead code and only ``test-other``
    was ever described.  The reference statistics appended at the
    bottom of this file cover all seven subsets, so iterate over all
    of them instead.
    """
    paths = [
        "./data/fbank/cuts_train-clean-100.json.gz",
        "./data/fbank/cuts_train-clean-360.json.gz",
        "./data/fbank/cuts_train-other-500.json.gz",
        "./data/fbank/cuts_dev-clean.json.gz",
        "./data/fbank/cuts_dev-other.json.gz",
        "./data/fbank/cuts_test-clean.json.gz",
        "./data/fbank/cuts_test-other.json.gz",
    ]

    for path in paths:
        cuts = load_manifest(path)
        # describe() prints count, total duration, and duration
        # percentiles to stdout.
        cuts.describe()
|
||||
|
||||
|
||||
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
||||
|
||||
"""
|
||||
## train-clean-100
|
||||
Cuts count: 85617
|
||||
Total duration (hours): 303.8
|
||||
Speech duration (hours): 303.8 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 12.8
|
||||
std 3.8
|
||||
min 1.3
|
||||
0.1% 1.9
|
||||
0.5% 2.2
|
||||
1% 2.5
|
||||
5% 4.2
|
||||
10% 6.4
|
||||
25% 11.4
|
||||
50% 13.8
|
||||
75% 15.3
|
||||
90% 16.7
|
||||
95% 17.3
|
||||
99% 18.1
|
||||
99.5% 18.4
|
||||
99.9% 18.8
|
||||
max 27.2
|
||||
|
||||
## train-clean-360
|
||||
Cuts count: 312042
|
||||
Total duration (hours): 1098.2
|
||||
Speech duration (hours): 1098.2 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 12.7
|
||||
std 3.8
|
||||
min 1.0
|
||||
0.1% 1.8
|
||||
0.5% 2.2
|
||||
1% 2.5
|
||||
5% 4.2
|
||||
10% 6.2
|
||||
25% 11.2
|
||||
50% 13.7
|
||||
75% 15.3
|
||||
90% 16.6
|
||||
95% 17.3
|
||||
99% 18.1
|
||||
99.5% 18.4
|
||||
99.9% 18.8
|
||||
max 33.0
|
||||
|
||||
## train-other-500
|
||||
Cuts count: 446064
|
||||
Total duration (hours): 1500.6
|
||||
Speech duration (hours): 1500.6 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 12.1
|
||||
std 4.2
|
||||
min 0.8
|
||||
0.1% 1.7
|
||||
0.5% 2.1
|
||||
1% 2.3
|
||||
5% 3.5
|
||||
10% 5.0
|
||||
25% 9.8
|
||||
50% 13.4
|
||||
75% 15.1
|
||||
90% 16.5
|
||||
95% 17.2
|
||||
99% 18.1
|
||||
99.5% 18.4
|
||||
99.9% 18.9
|
||||
max 31.0
|
||||
|
||||
## dev-clean
|
||||
Cuts count: 2703
|
||||
Total duration (hours): 5.4
|
||||
Speech duration (hours): 5.4 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 7.2
|
||||
std 4.7
|
||||
min 1.4
|
||||
0.1% 1.6
|
||||
0.5% 1.8
|
||||
1% 1.9
|
||||
5% 2.4
|
||||
10% 2.7
|
||||
25% 3.8
|
||||
50% 5.9
|
||||
75% 9.3
|
||||
90% 13.3
|
||||
95% 16.4
|
||||
99% 23.8
|
||||
99.5% 28.5
|
||||
99.9% 32.3
|
||||
max 32.6
|
||||
|
||||
## dev-other
|
||||
Cuts count: 2864
|
||||
Total duration (hours): 5.1
|
||||
Speech duration (hours): 5.1 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 6.4
|
||||
std 4.3
|
||||
min 1.1
|
||||
0.1% 1.3
|
||||
0.5% 1.7
|
||||
1% 1.8
|
||||
5% 2.2
|
||||
10% 2.6
|
||||
25% 3.5
|
||||
50% 5.3
|
||||
75% 7.9
|
||||
90% 12.0
|
||||
95% 15.0
|
||||
99% 22.2
|
||||
99.5% 27.1
|
||||
99.9% 32.4
|
||||
max 35.2
|
||||
|
||||
## test-clean
|
||||
Cuts count: 2620
|
||||
Total duration (hours): 5.4
|
||||
Speech duration (hours): 5.4 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 7.4
|
||||
std 5.2
|
||||
min 1.3
|
||||
0.1% 1.6
|
||||
0.5% 1.8
|
||||
1% 2.0
|
||||
5% 2.3
|
||||
10% 2.7
|
||||
25% 3.7
|
||||
50% 5.8
|
||||
75% 9.6
|
||||
90% 14.6
|
||||
95% 17.8
|
||||
99% 25.5
|
||||
99.5% 28.4
|
||||
99.9% 32.8
|
||||
max 35.0
|
||||
|
||||
## test-other
|
||||
Cuts count: 2939
|
||||
Total duration (hours): 5.3
|
||||
Speech duration (hours): 5.3 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 6.5
|
||||
std 4.4
|
||||
min 1.2
|
||||
0.1% 1.5
|
||||
0.5% 1.8
|
||||
1% 1.9
|
||||
5% 2.3
|
||||
10% 2.6
|
||||
25% 3.4
|
||||
50% 5.2
|
||||
75% 8.2
|
||||
90% 12.6
|
||||
95% 15.8
|
||||
99% 21.4
|
||||
99.5% 23.8
|
||||
99.9% 33.5
|
||||
max 34.5
|
||||
"""
|
@ -30,6 +30,7 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -176,7 +177,7 @@ def get_params() -> AttributeDict:
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000,
|
||||
"valid_interval": 3000, # For the 100h subset, use 800
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"encoder_out_dim": 512,
|
||||
@ -193,7 +194,7 @@ def get_params() -> AttributeDict:
|
||||
"decoder_hidden_dim": 512,
|
||||
# parameters for Noam
|
||||
"weight_decay": 1e-6,
|
||||
"warm_step": 80000,
|
||||
"warm_step": 80000, # For the 100h subset, use 8k
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -382,9 +383,8 @@ def compute_loss(
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# We use reduction="sum" in computing the loss.
|
||||
# The displayed loss is the average loss over the batch
|
||||
info["loss"] = loss.detach().cpu().item() / feature.size(0)
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
@ -535,6 +535,9 @@ def run(rank, world_size, args):
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
if params.full_libri is False:
|
||||
params.valid_interval = 800
|
||||
params.warm_step = 8000
|
||||
|
||||
fix_random_seed(42)
|
||||
if world_size > 1:
|
||||
@ -592,6 +595,23 @@ def run(rank, world_size, args):
|
||||
if params.full_libri:
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
train_cuts += librispeech.train_other_500_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
    """Return True if the cut's duration lies in [1, 20] seconds.

    Used as a filter predicate to drop utterances that are too short
    or too long (long ones can cause OOM with a large max_duration).
    """
    min_dur, max_dur = 1.0, 20.0
    return min_dur <= c.duration <= max_dur
|
||||
|
||||
num_in_total = len(train_cuts)
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
num_left = len(train_cuts)
|
||||
num_removed = num_in_total - num_left
|
||||
removed_percent = num_removed / num_in_total * 100
|
||||
|
||||
logging.info(f"Before removing short and long utterances: {num_in_total}")
|
||||
logging.info(f"After removing short and long utterances: {num_left}")
|
||||
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
||||
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
|
Loading…
x
Reference in New Issue
Block a user