[Ready to merge] Do some coding style checks for the latest files (#379)

* style check

* do changes for .flake8

* a change for compute_fbank_yesno.py
This commit is contained in:
Mingshuang Luo 2022-05-20 19:30:38 +08:00 committed by GitHub
parent 2900ed8f8f
commit ec5a112831
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 74 additions and 52 deletions

14
.flake8
View File

@ -4,15 +4,11 @@ statistics=true
max-line-length = 80
per-file-ignores =
# line too long
egs/librispeech/ASR/*/conformer.py: E501,
egs/aishell/ASR/*/conformer.py: E501,
egs/tedlium3/ASR/*/conformer.py: E501,
egs/gigaspeech/ASR/*/conformer.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
egs/gigaspeech/ASR/pruned_transducer_stateless2/*.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless4/*.py: E501,
egs/librispeech/ASR/*/optim.py: E501,
egs/librispeech/ASR/*/scaling.py: E501,
icefall/diagnostics.py: E501
egs/*/ASR/*/conformer.py: E501,
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
egs/*/ASR/*/optim.py: E501,
egs/*/ASR/*/scaling.py: E501,
# invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605

View File

@ -27,17 +27,15 @@ import logging
from pathlib import Path
import torch
from lhotse import LilcomChunkyWriter, CutSet, combine
from lhotse import CutSet, LilcomChunkyWriter, combine
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatMelOptions,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
@ -82,23 +80,28 @@ def compute_fbank_musan():
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
recordings=combine(
part["recordings"] for part in manifests.values()
)
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_musan",
manifest_path=src_dir / f"cuts_musan.jsonl.gz",
storage_path=output_dir / "feats_musan",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)
logging.info(f"Saving to {musan_cuts_path}")
musan_cuts.to_file(musan_cuts_path)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()

View File

@ -25,17 +25,15 @@ The generated fbank features are saved in data/fbank.
import argparse
import logging
from pathlib import Path
from tqdm import tqdm
import torch
from lhotse import load_manifest_lazy, LilcomChunkyWriter
from lhotse import LilcomChunkyWriter, load_manifest_lazy
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatMelOptions,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.manipulation import combine
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -97,27 +95,32 @@ def compute_fbank_spgispeech(args):
)
if args.train:
logging.info(f"Processing train")
cut_set = load_manifest_lazy(src_dir / f"cuts_train_raw.jsonl.gz")
logging.info("Processing train")
cut_set = load_manifest_lazy(src_dir / "cuts_train_raw.jsonl.gz")
chunk_size = len(cut_set) // args.num_splits
cut_sets = cut_set.split_lazy(
output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
chunk_size=chunk_size,
)
start = args.start
stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
stop = (
min(args.stop, args.num_splits)
if args.stop > 0
else args.num_splits
)
num_digits = len(str(args.num_splits))
for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits)
cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz"
logging.info(f"Processing train split {i}")
cs = cut_sets[i].compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_train_{idx}",
manifest_path=src_dir / f"cuts_train_{idx}.jsonl.gz",
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
cs.to_file(cuts_train_idx_path)
if args.test:
for partition in ["dev", "val"]:
@ -125,7 +128,9 @@ def compute_fbank_spgispeech(args):
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
cut_set = load_manifest_lazy(
src_dir / f"cuts_{partition}_raw.jsonl.gz"
)
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"feats_{partition}",
@ -137,8 +142,9 @@ def compute_fbank_spgispeech(args):
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()

View File

@ -24,7 +24,6 @@ from pathlib import Path
import torch
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
@ -56,7 +55,9 @@ def split_spgispeech_train():
# Add speed perturbation
train_cuts = (
train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
train_cuts
+ train_cuts.perturb_speed(0.9)
+ train_cuts.perturb_speed(1.1)
)
# Write the manifests to disk.
@ -72,8 +73,9 @@ def split_spgispeech_train():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
split_spgispeech_train()

View File

@ -38,7 +38,9 @@ def compute_fbank_yesno():
"test",
)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=src_dir
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix="yesno",
)
assert manifests is not None

View File

@ -18,8 +18,9 @@
import random
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
@ -90,8 +91,6 @@ def get_tensor_stats(
return x, count
@dataclass
class TensorAndCount:
tensor: Tensor
@ -108,12 +107,12 @@ class TensorDiagnostic(object):
name:
The tensor name.
"""
def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name
self.opts = opts
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
# the keys into self.stats[dim] are strings, whose values can be
# "abs", "value", "positive", "rms", "value".
@ -125,7 +124,6 @@ class TensorDiagnostic(object):
# only adding a new element to the list if there was a different dim.
# if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
def accumulate(self, x):
"""Accumulate tensors."""
if isinstance(x, Tuple):
@ -137,7 +135,7 @@ class TensorDiagnostic(object):
x = x.unsqueeze(0)
ndim = x.ndim
if self.stats is None:
self.stats = [ dict() for _ in range(ndim) ]
self.stats = [dict() for _ in range(ndim)]
for dim in range(ndim):
this_dim_stats = self.stats[dim]
@ -147,10 +145,10 @@ class TensorDiagnostic(object):
stats_types.append("eigs")
else:
stats_types = ["value", "abs"]
this_dict = self.stats[dim]
for stats_type in stats_types:
stats, count = get_tensor_stats(x, dim, stats_type)
if not stats_type in this_dim_stats:
if stats_type not in this_dim_stats:
this_dim_stats[stats_type] = [] # list of TensorAndCount
done = False
@ -166,13 +164,17 @@ class TensorDiagnostic(object):
done = True
break
if not done:
if this_dim_stats[stats_type] != [] and stats_type == "eigs":
if (
this_dim_stats[stats_type] != []
and stats_type == "eigs"
):
# >1 size encountered on this dim, e.g. it's a batch or time dimension,
# don't accumulat "eigs" stats type, it uses too much memory
this_dim_stats[stats_type] = None
else:
this_dim_stats[stats_type].append(TensorAndCount(stats, count))
this_dim_stats[stats_type].append(
TensorAndCount(stats, count)
)
def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor."""
@ -191,14 +193,18 @@ class TensorDiagnostic(object):
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
print(
"Error getting eigenvalues, trying another method."
)
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance
elif len(stats_list) == 1:
stats = stats_list[0].tensor / stats_list[0].count
else:
stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0)
stats = torch.cat(
[x.tensor / x.count for x in stats_list], dim=0
)
if stats_type == "rms":
# we stored the square; after aggregation we need to take sqrt.
@ -206,7 +212,9 @@ class TensorDiagnostic(object):
# if `summarize` we print percentiles of the stats; else,
# we print out individual elements.
summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel())
summarize = (
len(stats_list) > 1
) or self.opts.dim_is_summarized(stats.numel())
if summarize: # usually `summarize` will be true
# print out percentiles.
stats = stats.sort()[0]
@ -238,9 +246,14 @@ class TensorDiagnostic(object):
# ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
sizes = [x.tensor.shape[0] for x in stats_list]
size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}")
size_str = (
f"{sizes[0]}"
if len(sizes) == 1
else f"{min(sizes)}..{max(sizes)}"
)
print(
f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}"
)
class ModelDiagnostic(object):