[Ready to merge] Do some coding style checks for the latest files (#379)

* style check

* do changes for .flake8

* a change for compute_fbank_yesno.py
This commit is contained in:
Mingshuang Luo 2022-05-20 19:30:38 +08:00 committed by GitHub
parent 2900ed8f8f
commit ec5a112831
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 74 additions and 52 deletions

14
.flake8
View File

@ -4,15 +4,11 @@ statistics=true
max-line-length = 80 max-line-length = 80
per-file-ignores = per-file-ignores =
# line too long # line too long
egs/librispeech/ASR/*/conformer.py: E501, icefall/diagnostics.py: E501
egs/aishell/ASR/*/conformer.py: E501, egs/*/ASR/*/conformer.py: E501,
egs/tedlium3/ASR/*/conformer.py: E501, egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
egs/gigaspeech/ASR/*/conformer.py: E501, egs/*/ASR/*/optim.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501, egs/*/ASR/*/scaling.py: E501,
egs/gigaspeech/ASR/pruned_transducer_stateless2/*.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless4/*.py: E501,
egs/librispeech/ASR/*/optim.py: E501,
egs/librispeech/ASR/*/scaling.py: E501,
# invalid escape sequence (cause by tex formular), W605 # invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605 icefall/utils.py: E501, W605

View File

@ -27,17 +27,15 @@ import logging
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import LilcomChunkyWriter, CutSet, combine from lhotse import CutSet, LilcomChunkyWriter, combine
from lhotse.features.kaldifeat import ( from lhotse.features.kaldifeat import (
KaldifeatFbank, KaldifeatFbank,
KaldifeatFbankConfig, KaldifeatFbankConfig,
KaldifeatMelOptions,
KaldifeatFrameOptions, KaldifeatFrameOptions,
KaldifeatMelOptions,
) )
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or # Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down. # it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect # Do this outside of main() in case it needs to take effect
@ -82,23 +80,28 @@ def compute_fbank_musan():
# create chunks of Musan with duration 5 - 10 seconds # create chunks of Musan with duration 5 - 10 seconds
musan_cuts = ( musan_cuts = (
CutSet.from_manifests( CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values()) recordings=combine(
part["recordings"] for part in manifests.values()
)
) )
.cut_into_windows(10.0) .cut_into_windows(10.0)
.filter(lambda c: c.duration > 5) .filter(lambda c: c.duration > 5)
.compute_and_store_features_batch( .compute_and_store_features_batch(
extractor=extractor, extractor=extractor,
storage_path=output_dir / f"feats_musan", storage_path=output_dir / "feats_musan",
manifest_path=src_dir / f"cuts_musan.jsonl.gz",
batch_duration=500, batch_duration=500,
num_workers=4, num_workers=4,
storage_type=LilcomChunkyWriter, storage_type=LilcomChunkyWriter,
) )
) )
logging.info(f"Saving to {musan_cuts_path}")
musan_cuts.to_file(musan_cuts_path)
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan() compute_fbank_musan()

View File

@ -25,17 +25,15 @@ The generated fbank features are saved in data/fbank.
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
from tqdm import tqdm
import torch import torch
from lhotse import load_manifest_lazy, LilcomChunkyWriter from lhotse import LilcomChunkyWriter, load_manifest_lazy
from lhotse.features.kaldifeat import ( from lhotse.features.kaldifeat import (
KaldifeatFbank, KaldifeatFbank,
KaldifeatFbankConfig, KaldifeatFbankConfig,
KaldifeatMelOptions,
KaldifeatFrameOptions, KaldifeatFrameOptions,
KaldifeatMelOptions,
) )
from lhotse.manipulation import combine
# Torch's multithreaded behavior needs to be disabled or # Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down. # it wastes a lot of CPU and slow things down.
@ -97,27 +95,32 @@ def compute_fbank_spgispeech(args):
) )
if args.train: if args.train:
logging.info(f"Processing train") logging.info("Processing train")
cut_set = load_manifest_lazy(src_dir / f"cuts_train_raw.jsonl.gz") cut_set = load_manifest_lazy(src_dir / "cuts_train_raw.jsonl.gz")
chunk_size = len(cut_set) // args.num_splits chunk_size = len(cut_set) // args.num_splits
cut_sets = cut_set.split_lazy( cut_sets = cut_set.split_lazy(
output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}", output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
chunk_size=chunk_size, chunk_size=chunk_size,
) )
start = args.start start = args.start
stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits stop = (
min(args.stop, args.num_splits)
if args.stop > 0
else args.num_splits
)
num_digits = len(str(args.num_splits)) num_digits = len(str(args.num_splits))
for i in range(start, stop): for i in range(start, stop):
idx = f"{i + 1}".zfill(num_digits) idx = f"{i + 1}".zfill(num_digits)
cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz"
logging.info(f"Processing train split {i}") logging.info(f"Processing train split {i}")
cs = cut_sets[i].compute_and_store_features_batch( cs = cut_sets[i].compute_and_store_features_batch(
extractor=extractor, extractor=extractor,
storage_path=output_dir / f"feats_train_{idx}", storage_path=output_dir / f"feats_train_{idx}",
manifest_path=src_dir / f"cuts_train_{idx}.jsonl.gz",
batch_duration=500, batch_duration=500,
num_workers=4, num_workers=4,
storage_type=LilcomChunkyWriter, storage_type=LilcomChunkyWriter,
) )
cs.to_file(cuts_train_idx_path)
if args.test: if args.test:
for partition in ["dev", "val"]: for partition in ["dev", "val"]:
@ -125,7 +128,9 @@ def compute_fbank_spgispeech(args):
logging.info(f"{partition} already exists - skipping.") logging.info(f"{partition} already exists - skipping.")
continue continue
logging.info(f"Processing {partition}") logging.info(f"Processing {partition}")
cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz") cut_set = load_manifest_lazy(
src_dir / f"cuts_{partition}_raw.jsonl.gz"
)
cut_set = cut_set.compute_and_store_features_batch( cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor, extractor=extractor,
storage_path=output_dir / f"feats_{partition}", storage_path=output_dir / f"feats_{partition}",
@ -137,8 +142,9 @@ def compute_fbank_spgispeech(args):
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args() args = get_args()

View File

@ -24,7 +24,6 @@ from pathlib import Path
import torch import torch
from lhotse import CutSet from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or # Torch's multithreaded behavior needs to be disabled or
@ -56,7 +55,9 @@ def split_spgispeech_train():
# Add speed perturbation # Add speed perturbation
train_cuts = ( train_cuts = (
train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1) train_cuts
+ train_cuts.perturb_speed(0.9)
+ train_cuts.perturb_speed(1.1)
) )
# Write the manifests to disk. # Write the manifests to disk.
@ -72,8 +73,9 @@ def split_spgispeech_train():
if __name__ == "__main__": if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
split_spgispeech_train() split_spgispeech_train()

View File

@ -38,7 +38,9 @@ def compute_fbank_yesno():
"test", "test",
) )
manifests = read_manifests_if_cached( manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, output_dir=src_dir dataset_parts=dataset_parts,
output_dir=src_dir,
prefix="yesno",
) )
assert manifests is not None assert manifests is not None

View File

@ -18,8 +18,9 @@
import random import random
from typing import List, Optional, Tuple
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -90,8 +91,6 @@ def get_tensor_stats(
return x, count return x, count
@dataclass @dataclass
class TensorAndCount: class TensorAndCount:
tensor: Tensor tensor: Tensor
@ -108,12 +107,12 @@ class TensorDiagnostic(object):
name: name:
The tensor name. The tensor name.
""" """
def __init__(self, opts: TensorDiagnosticOptions, name: str): def __init__(self, opts: TensorDiagnosticOptions, name: str):
self.name = name self.name = name
self.opts = opts self.opts = opts
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
self.stats = None # we'll later assign a list to this data member. It's a list of dict.
# the keys into self.stats[dim] are strings, whose values can be # the keys into self.stats[dim] are strings, whose values can be
# "abs", "value", "positive", "rms", "value". # "abs", "value", "positive", "rms", "value".
@ -125,7 +124,6 @@ class TensorDiagnostic(object):
# only adding a new element to the list if there was a different dim. # only adding a new element to the list if there was a different dim.
# if the string in the key is "eigs", if we detect a length mismatch we put None as the value. # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
def accumulate(self, x): def accumulate(self, x):
"""Accumulate tensors.""" """Accumulate tensors."""
if isinstance(x, Tuple): if isinstance(x, Tuple):
@ -137,7 +135,7 @@ class TensorDiagnostic(object):
x = x.unsqueeze(0) x = x.unsqueeze(0)
ndim = x.ndim ndim = x.ndim
if self.stats is None: if self.stats is None:
self.stats = [ dict() for _ in range(ndim) ] self.stats = [dict() for _ in range(ndim)]
for dim in range(ndim): for dim in range(ndim):
this_dim_stats = self.stats[dim] this_dim_stats = self.stats[dim]
@ -147,10 +145,10 @@ class TensorDiagnostic(object):
stats_types.append("eigs") stats_types.append("eigs")
else: else:
stats_types = ["value", "abs"] stats_types = ["value", "abs"]
this_dict = self.stats[dim]
for stats_type in stats_types: for stats_type in stats_types:
stats, count = get_tensor_stats(x, dim, stats_type) stats, count = get_tensor_stats(x, dim, stats_type)
if not stats_type in this_dim_stats: if stats_type not in this_dim_stats:
this_dim_stats[stats_type] = [] # list of TensorAndCount this_dim_stats[stats_type] = [] # list of TensorAndCount
done = False done = False
@ -166,13 +164,17 @@ class TensorDiagnostic(object):
done = True done = True
break break
if not done: if not done:
if this_dim_stats[stats_type] != [] and stats_type == "eigs": if (
this_dim_stats[stats_type] != []
and stats_type == "eigs"
):
# >1 size encountered on this dim, e.g. it's a batch or time dimension, # >1 size encountered on this dim, e.g. it's a batch or time dimension,
# don't accumulat "eigs" stats type, it uses too much memory # don't accumulat "eigs" stats type, it uses too much memory
this_dim_stats[stats_type] = None this_dim_stats[stats_type] = None
else: else:
this_dim_stats[stats_type].append(TensorAndCount(stats, count)) this_dim_stats[stats_type].append(
TensorAndCount(stats, count)
)
def print_diagnostics(self): def print_diagnostics(self):
"""Print diagnostics for each dimension of the tensor.""" """Print diagnostics for each dimension of the tensor."""
@ -191,14 +193,18 @@ class TensorDiagnostic(object):
eigs, _ = torch.symeig(stats) eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt() stats = eigs.abs().sqrt()
except: # noqa except: # noqa
print("Error getting eigenvalues, trying another method.") print(
"Error getting eigenvalues, trying another method."
)
eigs = torch.linalg.eigvals(stats) eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt() stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance # sqrt so it reflects data magnitude, like stddev- not variance
elif len(stats_list) == 1: elif len(stats_list) == 1:
stats = stats_list[0].tensor / stats_list[0].count stats = stats_list[0].tensor / stats_list[0].count
else: else:
stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0) stats = torch.cat(
[x.tensor / x.count for x in stats_list], dim=0
)
if stats_type == "rms": if stats_type == "rms":
# we stored the square; after aggregation we need to take sqrt. # we stored the square; after aggregation we need to take sqrt.
@ -206,7 +212,9 @@ class TensorDiagnostic(object):
# if `summarize` we print percentiles of the stats; else, # if `summarize` we print percentiles of the stats; else,
# we print out individual elements. # we print out individual elements.
summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel()) summarize = (
len(stats_list) > 1
) or self.opts.dim_is_summarized(stats.numel())
if summarize: # usually `summarize` will be true if summarize: # usually `summarize` will be true
# print out percentiles. # print out percentiles.
stats = stats.sort()[0] stats = stats.sort()[0]
@ -238,9 +246,14 @@ class TensorDiagnostic(object):
# ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5" # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
sizes = [x.tensor.shape[0] for x in stats_list] sizes = [x.tensor.shape[0] for x in stats_list]
size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" size_str = (
print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}") f"{sizes[0]}"
if len(sizes) == 1
else f"{min(sizes)}..{max(sizes)}"
)
print(
f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}"
)
class ModelDiagnostic(object): class ModelDiagnostic(object):