From 89a08b64cef3c654f04ace409f7d3ab28d6a0a21 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 13 Dec 2021 16:41:14 +0800
Subject: [PATCH] Remove long utterances to avoid OOM when a large
 max_duration is used.

---
 .../ASR/local/display_manifest_statistics.py | 215 ++++++++++++++++++
 egs/librispeech/ASR/transducer/train.py      |  30 ++-
 2 files changed, 240 insertions(+), 5 deletions(-)
 create mode 100755 egs/librispeech/ASR/local/display_manifest_statistics.py

diff --git a/egs/librispeech/ASR/local/display_manifest_statistics.py b/egs/librispeech/ASR/local/display_manifest_statistics.py
new file mode 100755
index 000000000..15bd206fa
--- /dev/null
+++ b/egs/librispeech/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,215 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed values to choose the minimum/maximum duration
+for removing short and long utterances during training.
+
+See the function `remove_short_and_long_utt()` in transducer/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest
+
+
+def main():
+    # path = "./data/fbank/cuts_train-clean-100.json.gz"
+    # path = "./data/fbank/cuts_train-clean-360.json.gz"
+    # path = "./data/fbank/cuts_train-other-500.json.gz"
+    # path = "./data/fbank/cuts_dev-clean.json.gz"
+    # path = "./data/fbank/cuts_dev-other.json.gz"
+    # path = "./data/fbank/cuts_test-clean.json.gz"
+    path = "./data/fbank/cuts_test-other.json.gz"
+
+    cuts = load_manifest(path)
+    cuts.describe()
+
+
+if __name__ == "__main__":
+    main()
+
+"""
+## train-clean-100
+Cuts count: 85617
+Total duration (hours): 303.8
+Speech duration (hours): 303.8 (100.0%)
+***
+Duration statistics (seconds):
+mean   12.8
+std    3.8
+min    1.3
+0.1%   1.9
+0.5%   2.2
+1%     2.5
+5%     4.2
+10%    6.4
+25%    11.4
+50%    13.8
+75%    15.3
+90%    16.7
+95%    17.3
+99%    18.1
+99.5%  18.4
+99.9%  18.8
+max    27.2
+
+## train-clean-360
+Cuts count: 312042
+Total duration (hours): 1098.2
+Speech duration (hours): 1098.2 (100.0%)
+***
+Duration statistics (seconds):
+mean   12.7
+std    3.8
+min    1.0
+0.1%   1.8
+0.5%   2.2
+1%     2.5
+5%     4.2
+10%    6.2
+25%    11.2
+50%    13.7
+75%    15.3
+90%    16.6
+95%    17.3
+99%    18.1
+99.5%  18.4
+99.9%  18.8
+max    33.0
+
+## train-other-500
+Cuts count: 446064
+Total duration (hours): 1500.6
+Speech duration (hours): 1500.6 (100.0%)
+***
+Duration statistics (seconds):
+mean   12.1
+std    4.2
+min    0.8
+0.1%   1.7
+0.5%   2.1
+1%     2.3
+5%     3.5
+10%    5.0
+25%    9.8
+50%    13.4
+75%    15.1
+90%    16.5
+95%    17.2
+99%    18.1
+99.5%  18.4
+99.9%  18.9
+max    31.0
+
+## dev-clean
+Cuts count: 2703
+Total duration (hours): 5.4
+Speech duration (hours): 5.4 (100.0%)
+***
+Duration statistics (seconds):
+mean   7.2
+std    4.7
+min    1.4
+0.1%   1.6
+0.5%   1.8
+1%     1.9
+5%     2.4
+10%    2.7
+25%    3.8
+50%    5.9
+75%    9.3
+90%    13.3
+95%    16.4
+99%    23.8
+99.5%  28.5
+99.9%  32.3
+max    32.6
+
+## dev-other
+Cuts count: 2864
+Total duration (hours): 5.1
+Speech duration (hours): 5.1 (100.0%)
+***
+Duration statistics (seconds):
+mean   6.4
+std    4.3
+min    1.1
+0.1%   1.3
+0.5%   1.7
+1%     1.8
+5%     2.2
+10%    2.6
+25%    3.5
+50%    5.3
+75%    7.9
+90%    12.0
+95%    15.0
+99%    22.2
+99.5%  27.1
+99.9%  32.4
+max    35.2
+
+## test-clean
+Cuts count: 2620
+Total duration (hours): 5.4
+Speech duration (hours): 5.4 (100.0%)
+***
+Duration statistics (seconds):
+mean   7.4
+std    5.2
+min    1.3
+0.1%   1.6
+0.5%   1.8
+1%     2.0
+5%     2.3
+10%    2.7
+25%    3.7
+50%    5.8
+75%    9.6
+90%    14.6
+95%    17.8
+99%    25.5
+99.5%  28.4
+99.9%  32.8
+max    35.0
+
+## test-other
+Cuts count: 2939
+Total duration (hours): 5.3
+Speech duration (hours): 5.3 (100.0%)
+***
+Duration statistics (seconds):
+mean   6.5
+std    4.4
+min    1.2
+0.1%   1.5
+0.5%   1.8
+1%     1.9
+5%     2.3
+10%    2.6
+25%    3.4
+50%    5.2
+75%    8.2
+90%    12.6
+95%    15.8
+99%    21.4
+99.5%  23.8
+99.9%  33.5
+max    34.5
+"""
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index b5dbe02e9..c80ea4bbc 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -30,6 +30,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -176,7 +177,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,
+            "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for conformer
             "feature_dim": 80,
             "encoder_out_dim": 512,
@@ -193,7 +194,7 @@ def get_params() -> AttributeDict:
             "decoder_hidden_dim": 512,
             # parameters for Noam
             "weight_decay": 1e-6,
-            "warm_step": 80000,
+            "warm_step": 80000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
         }
     )
@@ -382,9 +383,8 @@ def compute_loss(
     info = MetricsTracker()
     info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
-    # We use reduction="sum" in computing the loss.
-    # The displayed loss is the average loss over the batch
-    info["loss"] = loss.detach().cpu().item() / feature.size(0)
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
 
     return loss, info
 
@@ -535,6 +535,9 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 800
+        params.warm_step = 8000
 
     fix_random_seed(42)
     if world_size > 1:
@@ -592,6 +595,23 @@ def run(rank, world_size, args):
     if params.full_libri:
         train_cuts += librispeech.train_clean_360_cuts()
         train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        return 1.0 <= c.duration <= 20.0
+
+    num_in_total = len(train_cuts)
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    num_left = len(train_cuts)
+    num_removed = num_in_total - num_left
+    removed_percent = num_removed / num_in_total * 100
+
+    logging.info(f"Before removing short and long utterances: {num_in_total}")
+    logging.info(f"After removing short and long utterances: {num_left}")
+    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+
     train_dl = librispeech.train_dataloaders(train_cuts)
 
     valid_cuts = librispeech.dev_clean_cuts()
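
As a quick sanity check, the effect of the new filter can be previewed with lhotse alone, before launching a training run. The sketch below is illustrative rather than part of the patch: it assumes only that lhotse is installed and that the fbank manifests listed above exist, and it reuses the 1 s / 20 s thresholds from remove_short_and_long_utt():

#!/usr/bin/env python3
# Illustrative sketch (not part of the patch): preview how many cuts the
# 1 s <= duration <= 20 s filter in transducer/train.py would drop.
from lhotse import load_manifest

cuts = load_manifest("./data/fbank/cuts_train-clean-100.json.gz")
cuts.describe()  # prints duration statistics like those quoted above

num_in_total = len(cuts)
kept = cuts.filter(lambda c: 1.0 <= c.duration <= 20.0)
num_left = len(kept)
print(f"Would remove {num_in_total - num_left} of {num_in_total} cuts")

For train-clean-100, the statistics above show a minimum duration of 1.3 s and a 99.9th percentile of 18.8 s, so the 1 s floor removes nothing and the 20 s cap removes well under 0.1% of that subset.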