From d68b8e91202d02924f7df6e5c843d03425657f33 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 28 Aug 2022 11:17:38 +0800
Subject: [PATCH] Disable CUDA_LAUNCH_BLOCKING in wenetspeech recipes. (#554)

* Disable CUDA_LAUNCH_BLOCKING in wenetspeech recipes.

* minor fixes
---
 .../ASR/local/preprocess_wenetspeech.py       | 25 ++++++++++++-------
 .../ASR/pruned_transducer_stateless2/train.py | 14 ++++-------
 .../ASR/pruned_transducer_stateless5/train.py |  5 +---
 3 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 64733eb15..f4c71230b 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -23,6 +23,8 @@ from pathlib import Path
 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached
 
+from icefall import setup_logger
+
 # Similar text filtering and normalization procedure as in:
 # https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
 
@@ -48,13 +50,17 @@ def preprocess_wenet_speech():
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)
 
+    # Note: By default, we preprocess all sub-parts.
+    # You can delete those that you don't need.
+    # For instance, if you don't want to use the L subpart, just remove
+    # the line below containing "L"
     dataset_parts = (
-        "L",
-        "M",
-        "S",
         "DEV",
         "TEST_NET",
         "TEST_MEETING",
+        "S",
+        "M",
+        "L",
     )
 
     logging.info("Loading manifest (may take 10 minutes)")
@@ -81,10 +87,13 @@ def preprocess_wenet_speech():
         logging.info(f"Normalizing text in {partition}")
         for sup in m["supervisions"]:
             text = str(sup.text)
-            logging.info(f"Original text: {text}")
+            orig_text = text
             sup.text = normalize_text(sup.text)
             text = str(sup.text)
-            logging.info(f"Normalize text: {text}")
+            if len(orig_text) != len(text):
+                logging.info(
+                    f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
+                )
 
         # Create long-recording cut manifests.
         logging.info(f"Processing {partition}")
@@ -109,12 +118,10 @@ def preprocess_wenet_speech():
 
 
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-    logging.basicConfig(format=formatter, level=logging.INFO)
+    setup_logger(log_filename="./log-preprocess-wenetspeech")
 
     preprocess_wenet_speech()
+    logging.info("Done")
 
 
 if __name__ == "__main__":
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 5208dbefe..d3cc7c9c9 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -81,7 +81,6 @@ For training with the S subset:
 
 import argparse
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -120,8 +119,6 @@ LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]
 
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -162,7 +159,7 @@ def get_parser():
         default=0,
         help="""Resume training from from this epoch.
        If it is positive, it will load checkpoint from
-        transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
        """,
    )
 
@@ -361,8 +358,8 @@ def get_params() -> AttributeDict:
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
-            "batch_idx_train": 10,
-            "log_interval": 1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
             "reset_interval": 200,
             # parameters for conformer
             "feature_dim": 80,
@@ -545,7 +542,7 @@ def compute_loss(
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute RNN-T loss given the model and its inputs.
     Args:
       params:
         Parameters for training. See :func:`get_params`.
@@ -573,7 +570,7 @@ def compute_loss(
 
     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)
@@ -697,7 +694,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(train_dl):
-
         params.batch_idx_train += 1
 
         batch_size = len(batch["supervisions"]["text"])
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 5a5925d55..2052e9da7 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -61,7 +61,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 
 import argparse
 import copy
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -103,8 +102,6 @@ LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]
 
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
 
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
@@ -684,7 +681,7 @@ def compute_loss(
 
     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)