From ec5a112831b54471daa82ddb956112b88128577a Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Fri, 20 May 2022 19:30:38 +0800 Subject: [PATCH] [Ready to merge] Do some coding style checks for the latest files (#379) * style check * do changes for .flake8 * a change for compute_fbank_yesno.py --- .flake8 | 14 ++---- .../ASR/local/compute_fbank_musan.py | 21 ++++---- .../ASR/local/compute_fbank_spgispeech.py | 28 ++++++----- egs/spgispeech/ASR/local/prepare_splits.py | 10 ++-- egs/yesno/ASR/local/compute_fbank_yesno.py | 4 +- icefall/diagnostics.py | 49 ++++++++++++------- 6 files changed, 74 insertions(+), 52 deletions(-) diff --git a/.flake8 b/.flake8 index 8c497fac3..dbeec0b0c 100644 --- a/.flake8 +++ b/.flake8 @@ -4,15 +4,11 @@ statistics=true max-line-length = 80 per-file-ignores = # line too long - egs/librispeech/ASR/*/conformer.py: E501, - egs/aishell/ASR/*/conformer.py: E501, - egs/tedlium3/ASR/*/conformer.py: E501, - egs/gigaspeech/ASR/*/conformer.py: E501, - egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501, - egs/gigaspeech/ASR/pruned_transducer_stateless2/*.py: E501, - egs/librispeech/ASR/pruned_transducer_stateless4/*.py: E501, - egs/librispeech/ASR/*/optim.py: E501, - egs/librispeech/ASR/*/scaling.py: E501, + icefall/diagnostics.py: E501 + egs/*/ASR/*/conformer.py: E501, + egs/*/ASR/pruned_transducer_stateless*/*.py: E501, + egs/*/ASR/*/optim.py: E501, + egs/*/ASR/*/scaling.py: E501, # invalid escape sequence (cause by tex formular), W605 icefall/utils.py: E501, W605 diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py index b4f409ba6..57805a756 100755 --- a/egs/spgispeech/ASR/local/compute_fbank_musan.py +++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py @@ -27,17 +27,15 @@ import logging from pathlib import Path import torch -from lhotse import LilcomChunkyWriter, CutSet, combine +from lhotse import CutSet, LilcomChunkyWriter, combine from lhotse.features.kaldifeat import ( KaldifeatFbank, KaldifeatFbankConfig, - KaldifeatMelOptions, KaldifeatFrameOptions, + KaldifeatMelOptions, ) from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor - # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -82,23 +80,28 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) + recordings=combine( + part["recordings"] for part in manifests.values() + ) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) .compute_and_store_features_batch( extractor=extractor, - storage_path=output_dir / f"feats_musan", - manifest_path=src_dir / f"cuts_musan.jsonl.gz", + storage_path=output_dir / "feats_musan", batch_duration=500, num_workers=4, storage_type=LilcomChunkyWriter, ) ) + logging.info(f"Saving to {musan_cuts_path}") + musan_cuts.to_file(musan_cuts_path) + if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py index cc8c8a670..b67754e2a 100755 --- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py +++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py @@ -25,17 +25,15 @@ The generated fbank features are saved in data/fbank. import argparse import logging from pathlib import Path -from tqdm import tqdm import torch -from lhotse import load_manifest_lazy, LilcomChunkyWriter +from lhotse import LilcomChunkyWriter, load_manifest_lazy from lhotse.features.kaldifeat import ( KaldifeatFbank, KaldifeatFbankConfig, - KaldifeatMelOptions, KaldifeatFrameOptions, + KaldifeatMelOptions, ) -from lhotse.manipulation import combine # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -97,27 +95,32 @@ def compute_fbank_spgispeech(args): ) if args.train: - logging.info(f"Processing train") - cut_set = load_manifest_lazy(src_dir / f"cuts_train_raw.jsonl.gz") + logging.info("Processing train") + cut_set = load_manifest_lazy(src_dir / "cuts_train_raw.jsonl.gz") chunk_size = len(cut_set) // args.num_splits cut_sets = cut_set.split_lazy( output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}", chunk_size=chunk_size, ) start = args.start - stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits + stop = ( + min(args.stop, args.num_splits) + if args.stop > 0 + else args.num_splits + ) num_digits = len(str(args.num_splits)) for i in range(start, stop): idx = f"{i + 1}".zfill(num_digits) + cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz" logging.info(f"Processing train split {i}") cs = cut_sets[i].compute_and_store_features_batch( extractor=extractor, storage_path=output_dir / f"feats_train_{idx}", - manifest_path=src_dir / f"cuts_train_{idx}.jsonl.gz", batch_duration=500, num_workers=4, storage_type=LilcomChunkyWriter, ) + cs.to_file(cuts_train_idx_path) if args.test: for partition in ["dev", "val"]: @@ -125,7 +128,9 @@ def compute_fbank_spgispeech(args): logging.info(f"{partition} already exists - skipping.") continue logging.info(f"Processing {partition}") - cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz") + cut_set = load_manifest_lazy( + src_dir / f"cuts_{partition}_raw.jsonl.gz" + ) cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, storage_path=output_dir / f"feats_{partition}", @@ -137,8 +142,9 @@ def compute_fbank_spgispeech(args): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py index 2d1818649..8c8f1c133 100755 --- a/egs/spgispeech/ASR/local/prepare_splits.py +++ b/egs/spgispeech/ASR/local/prepare_splits.py @@ -24,7 +24,6 @@ from pathlib import Path import torch from lhotse import CutSet - from lhotse.recipes.utils import read_manifests_if_cached # Torch's multithreaded behavior needs to be disabled or @@ -56,7 +55,9 @@ def split_spgispeech_train(): # Add speed perturbation train_cuts = ( - train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1) + train_cuts + + train_cuts.perturb_speed(0.9) + + train_cuts.perturb_speed(1.1) ) # Write the manifests to disk. @@ -72,8 +73,9 @@ def split_spgispeech_train(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) split_spgispeech_train() diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py index 6072d4222..6922ffe10 100755 --- a/egs/yesno/ASR/local/compute_fbank_yesno.py +++ b/egs/yesno/ASR/local/compute_fbank_yesno.py @@ -38,7 +38,9 @@ def compute_fbank_yesno(): "test", ) manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, output_dir=src_dir + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix="yesno", ) assert manifests is not None diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index e8bedc64e..4850308d9 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -18,8 +18,9 @@ import random -from typing import List, Optional, Tuple from dataclasses import dataclass +from typing import Optional, Tuple + import torch from torch import Tensor, nn @@ -90,8 +91,6 @@ def get_tensor_stats( return x, count - - @dataclass class TensorAndCount: tensor: Tensor @@ -108,12 +107,12 @@ class TensorDiagnostic(object): name: The tensor name. """ + def __init__(self, opts: TensorDiagnosticOptions, name: str): self.name = name self.opts = opts - - self.stats = None # we'll later assign a list to this data member. It's a list of dict. + self.stats = None # we'll later assign a list to this data member. It's a list of dict. # the keys into self.stats[dim] are strings, whose values can be # "abs", "value", "positive", "rms", "value". @@ -125,7 +124,6 @@ class TensorDiagnostic(object): # only adding a new element to the list if there was a different dim. # if the string in the key is "eigs", if we detect a length mismatch we put None as the value. - def accumulate(self, x): """Accumulate tensors.""" if isinstance(x, Tuple): @@ -137,7 +135,7 @@ class TensorDiagnostic(object): x = x.unsqueeze(0) ndim = x.ndim if self.stats is None: - self.stats = [ dict() for _ in range(ndim) ] + self.stats = [dict() for _ in range(ndim)] for dim in range(ndim): this_dim_stats = self.stats[dim] @@ -147,10 +145,10 @@ class TensorDiagnostic(object): stats_types.append("eigs") else: stats_types = ["value", "abs"] - this_dict = self.stats[dim] + for stats_type in stats_types: stats, count = get_tensor_stats(x, dim, stats_type) - if not stats_type in this_dim_stats: + if stats_type not in this_dim_stats: this_dim_stats[stats_type] = [] # list of TensorAndCount done = False @@ -166,13 +164,17 @@ class TensorDiagnostic(object): done = True break if not done: - if this_dim_stats[stats_type] != [] and stats_type == "eigs": + if ( + this_dim_stats[stats_type] != [] + and stats_type == "eigs" + ): # >1 size encountered on this dim, e.g. it's a batch or time dimension, # don't accumulat "eigs" stats type, it uses too much memory this_dim_stats[stats_type] = None else: - this_dim_stats[stats_type].append(TensorAndCount(stats, count)) - + this_dim_stats[stats_type].append( + TensorAndCount(stats, count) + ) def print_diagnostics(self): """Print diagnostics for each dimension of the tensor.""" @@ -191,14 +193,18 @@ class TensorDiagnostic(object): eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print("Error getting eigenvalues, trying another method.") + print( + "Error getting eigenvalues, trying another method." + ) eigs = torch.linalg.eigvals(stats) stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif len(stats_list) == 1: stats = stats_list[0].tensor / stats_list[0].count else: - stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0) + stats = torch.cat( + [x.tensor / x.count for x in stats_list], dim=0 + ) if stats_type == "rms": # we stored the square; after aggregation we need to take sqrt. @@ -206,7 +212,9 @@ class TensorDiagnostic(object): # if `summarize` we print percentiles of the stats; else, # we print out individual elements. - summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel()) + summarize = ( + len(stats_list) > 1 + ) or self.opts.dim_is_summarized(stats.numel()) if summarize: # usually `summarize` will be true # print out percentiles. stats = stats.sort()[0] @@ -238,9 +246,14 @@ class TensorDiagnostic(object): # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5" sizes = [x.tensor.shape[0] for x in stats_list] - size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" - print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}") - + size_str = ( + f"{sizes[0]}" + if len(sizes) == 1 + else f"{min(sizes)}..{max(sizes)}" + ) + print( + f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}" + ) class ModelDiagnostic(object):