mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Remove long utterances to avoid OOM when a large max_duraiton is used.
This commit is contained in:
parent
cd5ed7db20
commit
89a08b64ce
215
egs/librispeech/ASR/local/display_manifest_statistics.py
Executable file
215
egs/librispeech/ASR/local/display_manifest_statistics.py
Executable file
@ -0,0 +1,215 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file displays duration statistics of utterances in a manifest.
|
||||||
|
You can use the displayed value to choose minimum/maximum duration
|
||||||
|
to remove short and long utterances during the training.
|
||||||
|
|
||||||
|
See the function `remove_short_and_long_utt()` in transducer/train.py
|
||||||
|
for usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from lhotse import load_manifest
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
path = "./data/fbank/cuts_train-clean-100.json.gz"
|
||||||
|
path = "./data/fbank/cuts_train-clean-360.json.gz"
|
||||||
|
path = "./data/fbank/cuts_train-other-500.json.gz"
|
||||||
|
path = "./data/fbank/cuts_dev-clean.json.gz"
|
||||||
|
path = "./data/fbank/cuts_dev-other.json.gz"
|
||||||
|
path = "./data/fbank/cuts_test-clean.json.gz"
|
||||||
|
path = "./data/fbank/cuts_test-other.json.gz"
|
||||||
|
|
||||||
|
cuts = load_manifest(path)
|
||||||
|
cuts.describe()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
"""
|
||||||
|
## train-clean-100
|
||||||
|
Cuts count: 85617
|
||||||
|
Total duration (hours): 303.8
|
||||||
|
Speech duration (hours): 303.8 (100.0%)
|
||||||
|
***
|
||||||
|
Duration statistics (seconds):
|
||||||
|
mean 12.8
|
||||||
|
std 3.8
|
||||||
|
min 1.3
|
||||||
|
0.1% 1.9
|
||||||
|
0.5% 2.2
|
||||||
|
1% 2.5
|
||||||
|
5% 4.2
|
||||||
|
10% 6.4
|
||||||
|
25% 11.4
|
||||||
|
50% 13.8
|
||||||
|
75% 15.3
|
||||||
|
90% 16.7
|
||||||
|
95% 17.3
|
||||||
|
99% 18.1
|
||||||
|
99.5% 18.4
|
||||||
|
99.9% 18.8
|
||||||
|
max 27.2
|
||||||
|
|
||||||
|
## train-clean-360
|
||||||
|
Cuts count: 312042
|
||||||
|
Total duration (hours): 1098.2
|
||||||
|
Speech duration (hours): 1098.2 (100.0%)
|
||||||
|
***
|
||||||
|
Duration statistics (seconds):
|
||||||
|
mean 12.7
|
||||||
|
std 3.8
|
||||||
|
min 1.0
|
||||||
|
0.1% 1.8
|
||||||
|
0.5% 2.2
|
||||||
|
1% 2.5
|
||||||
|
5% 4.2
|
||||||
|
10% 6.2
|
||||||
|
25% 11.2
|
||||||
|
50% 13.7
|
||||||
|
75% 15.3
|
||||||
|
90% 16.6
|
||||||
|
95% 17.3
|
||||||
|
99% 18.1
|
||||||
|
99.5% 18.4
|
||||||
|
99.9% 18.8
|
||||||
|
max 33.0
|
||||||
|
|
||||||
|
## train-other 500
|
||||||
|
Cuts count: 446064
|
||||||
|
Total duration (hours): 1500.6
|
||||||
|
Speech duration (hours): 1500.6 (100.0%)
|
||||||
|
***
|
||||||
|
Duration statistics (seconds):
|
||||||
|
mean 12.1
|
||||||
|
std 4.2
|
||||||
|
min 0.8
|
||||||
|
0.1% 1.7
|
||||||
|
0.5% 2.1
|
||||||
|
1% 2.3
|
||||||
|
5% 3.5
|
||||||
|
10% 5.0
|
||||||
|
25% 9.8
|
||||||
|
50% 13.4
|
||||||
|
75% 15.1
|
||||||
|
90% 16.5
|
||||||
|
95% 17.2
|
||||||
|
99% 18.1
|
||||||
|
99.5% 18.4
|
||||||
|
99.9% 18.9
|
||||||
|
max 31.0
|
||||||
|
|
||||||
|
## dev-clean
|
||||||
|
Cuts count: 2703
|
||||||
|
Total duration (hours): 5.4
|
||||||
|
Speech duration (hours): 5.4 (100.0%)
|
||||||
|
***
|
||||||
|
Duration statistics (seconds):
|
||||||
|
mean 7.2
|
||||||
|
std 4.7
|
||||||
|
min 1.4
|
||||||
|
0.1% 1.6
|
||||||
|
0.5% 1.8
|
||||||
|
1% 1.9
|
||||||
|
5% 2.4
|
||||||
|
10% 2.7
|
||||||
|
25% 3.8
|
||||||
|
50% 5.9
|
||||||
|
75% 9.3
|
||||||
|
90% 13.3
|
||||||
|
95% 16.4
|
||||||
|
99% 23.8
|
||||||
|
99.5% 28.5
|
||||||
|
99.9% 32.3
|
||||||
|
max 32.6
|
||||||
|
|
||||||
|
## dev-other
|
||||||
|
Cuts count: 2864
|
||||||
|
Total duration (hours): 5.1
|
||||||
|
Speech duration (hours): 5.1 (100.0%)
|
||||||
|
***
|
||||||
|
Duration statistics (seconds):
|
||||||
|
mean 6.4
|
||||||
|
std 4.3
|
||||||
|
min 1.1
|
||||||
|
0.1% 1.3
|
||||||
|
0.5% 1.7
|
||||||
|
1% 1.8
|
||||||
|
5% 2.2
|
||||||
|
10% 2.6
|
||||||
|
25% 3.5
|
||||||
|
50% 5.3
|
||||||
|
75% 7.9
|
||||||
|
90% 12.0
|
||||||
|
95% 15.0
|
||||||
|
99% 22.2
|
||||||
|
99.5% 27.1
|
||||||
|
99.9% 32.4
|
||||||
|
max 35.2
|
||||||
|
|
||||||
|
## test-clean
|
||||||
|
Cuts count: 2620
|
||||||
|
Total duration (hours): 5.4
|
||||||
|
Speech duration (hours): 5.4 (100.0%)
|
||||||
|
***
|
||||||
|
Duration statistics (seconds):
|
||||||
|
mean 7.4
|
||||||
|
std 5.2
|
||||||
|
min 1.3
|
||||||
|
0.1% 1.6
|
||||||
|
0.5% 1.8
|
||||||
|
1% 2.0
|
||||||
|
5% 2.3
|
||||||
|
10% 2.7
|
||||||
|
25% 3.7
|
||||||
|
50% 5.8
|
||||||
|
75% 9.6
|
||||||
|
90% 14.6
|
||||||
|
95% 17.8
|
||||||
|
99% 25.5
|
||||||
|
99.5% 28.4
|
||||||
|
99.9% 32.8
|
||||||
|
max 35.0
|
||||||
|
|
||||||
|
## test-other
|
||||||
|
Cuts count: 2939
|
||||||
|
Total duration (hours): 5.3
|
||||||
|
Speech duration (hours): 5.3 (100.0%)
|
||||||
|
***
|
||||||
|
Duration statistics (seconds):
|
||||||
|
mean 6.5
|
||||||
|
std 4.4
|
||||||
|
min 1.2
|
||||||
|
0.1% 1.5
|
||||||
|
0.5% 1.8
|
||||||
|
1% 1.9
|
||||||
|
5% 2.3
|
||||||
|
10% 2.6
|
||||||
|
25% 3.4
|
||||||
|
50% 5.2
|
||||||
|
75% 8.2
|
||||||
|
90% 12.6
|
||||||
|
95% 15.8
|
||||||
|
99% 21.4
|
||||||
|
99.5% 23.8
|
||||||
|
99.9% 33.5
|
||||||
|
max 34.5
|
||||||
|
"""
|
@ -30,6 +30,7 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from lhotse.cut import Cut
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
@ -176,7 +177,7 @@ def get_params() -> AttributeDict:
|
|||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 50,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 3000,
|
"valid_interval": 3000, # For the 100h subset, use 800
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"encoder_out_dim": 512,
|
"encoder_out_dim": 512,
|
||||||
@ -193,7 +194,7 @@ def get_params() -> AttributeDict:
|
|||||||
"decoder_hidden_dim": 512,
|
"decoder_hidden_dim": 512,
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"weight_decay": 1e-6,
|
"weight_decay": 1e-6,
|
||||||
"warm_step": 80000,
|
"warm_step": 80000, # For the 100h subset, use 8k
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -382,9 +383,8 @@ def compute_loss(
|
|||||||
info = MetricsTracker()
|
info = MetricsTracker()
|
||||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||||
|
|
||||||
# We use reduction="sum" in computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
# The displayed loss is the average loss over the batch
|
info["loss"] = loss.detach().cpu().item()
|
||||||
info["loss"] = loss.detach().cpu().item() / feature.size(0)
|
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
@ -535,6 +535,9 @@ def run(rank, world_size, args):
|
|||||||
"""
|
"""
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
if params.full_libri is False:
|
||||||
|
params.valid_interval = 800
|
||||||
|
params.warm_step = 8000
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(42)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
@ -592,6 +595,23 @@ def run(rank, world_size, args):
|
|||||||
if params.full_libri:
|
if params.full_libri:
|
||||||
train_cuts += librispeech.train_clean_360_cuts()
|
train_cuts += librispeech.train_clean_360_cuts()
|
||||||
train_cuts += librispeech.train_other_500_cuts()
|
train_cuts += librispeech.train_other_500_cuts()
|
||||||
|
|
||||||
|
def remove_short_and_long_utt(c: Cut):
|
||||||
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
|
return 1.0 <= c.duration <= 20.0
|
||||||
|
|
||||||
|
num_in_total = len(train_cuts)
|
||||||
|
|
||||||
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
|
|
||||||
|
num_left = len(train_cuts)
|
||||||
|
num_removed = num_in_total - num_left
|
||||||
|
removed_percent = num_removed / num_in_total * 100
|
||||||
|
|
||||||
|
logging.info(f"Before removing short and long utterances: {num_in_total}")
|
||||||
|
logging.info(f"After removing short and long utterances: {num_left}")
|
||||||
|
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
||||||
|
|
||||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||||
|
|
||||||
valid_cuts = librispeech.dev_clean_cuts()
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user