Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
[Ready to merge] Do some coding style checks for the latest files (#379)
* style check
* do changes for .flake8
* a change for compute_fbank_yesno.py
This commit is contained in:
parent 2900ed8f8f
commit ec5a112831

.flake8 — 14 changed lines
@@ -4,15 +4,11 @@ statistics=true
 max-line-length = 80
 per-file-ignores =
     # line too long
-    egs/librispeech/ASR/*/conformer.py: E501,
-    egs/aishell/ASR/*/conformer.py: E501,
-    egs/tedlium3/ASR/*/conformer.py: E501,
-    egs/gigaspeech/ASR/*/conformer.py: E501,
-    egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
-    egs/gigaspeech/ASR/pruned_transducer_stateless2/*.py: E501,
-    egs/librispeech/ASR/pruned_transducer_stateless4/*.py: E501,
-    egs/librispeech/ASR/*/optim.py: E501,
-    egs/librispeech/ASR/*/scaling.py: E501,
-    icefall/diagnostics.py: E501
+    egs/*/ASR/*/conformer.py: E501,
+    egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
+    egs/*/ASR/*/optim.py: E501,
+    egs/*/ASR/*/scaling.py: E501,

     # invalid escape sequence (cause by tex formular), W605
     icefall/utils.py: E501, W605
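Aside: the retained entry for icefall/utils.py suppresses W605, which fires when a normal string literal contains a backslash escape Python does not define — common when TeX notation sits in a docstring. A minimal illustration (the strings are made up, not taken from icefall/utils.py):

```python
# In a normal string literal "\c" is not a recognized escape sequence, so
# flake8 reports W605 (and Python emits a DeprecationWarning):
bad = "scaled by \\lambda \cdot x"  # W605 on "\c"

# The usual fix is a raw string, which keeps every backslash literally,
# exactly as TeX needs it:
good = r"scaled by \lambda \cdot x"

assert bad == good  # same value; only the literal form differs
```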
@@ -27,17 +27,15 @@ import logging
 from pathlib import Path

 import torch
-from lhotse import LilcomChunkyWriter, CutSet, combine
+from lhotse import CutSet, LilcomChunkyWriter, combine
 from lhotse.features.kaldifeat import (
     KaldifeatFbank,
     KaldifeatFbankConfig,
-    KaldifeatMelOptions,
     KaldifeatFrameOptions,
+    KaldifeatMelOptions,
 )
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor

 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
 # Do this outside of main() in case it needs to take effect
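Aside: these import reorderings are exactly what an import sorter produces — names in a `from` import are alphabetized, as are the members of a parenthesized import list. A quick way to check a snippet, assuming isort 5.x and its `isort.code` API:

```python
import isort

messy = "from lhotse import LilcomChunkyWriter, CutSet, combine\n"
print(isort.code(messy), end="")
# from lhotse import CutSet, LilcomChunkyWriter, combine
```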
@@ -82,23 +80,28 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(part["recordings"] for part in manifests.values())
+            recordings=combine(
+                part["recordings"] for part in manifests.values()
+            )
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
         .compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=output_dir / f"feats_musan",
-            manifest_path=src_dir / f"cuts_musan.jsonl.gz",
+            storage_path=output_dir / "feats_musan",
             batch_duration=500,
             num_workers=4,
             storage_type=LilcomChunkyWriter,
         )
     )

     logging.info(f"Saving to {musan_cuts_path}")
     musan_cuts.to_file(musan_cuts_path)


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
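Aside: the dropped `f` prefixes fix flake8's F541 (f-string without any placeholders) — an f-string that interpolates nothing is just a plain literal with extra noise. A small illustration with a hypothetical path:

```python
from pathlib import Path

output_dir = Path("data/fbank")

storage_path = output_dir / f"feats_musan"  # F541: no placeholder, f is useless
storage_path = output_dir / "feats_musan"   # equivalent and clean
```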
@@ -25,17 +25,15 @@ The generated fbank features are saved in data/fbank.
 import argparse
 import logging
 from pathlib import Path
-from tqdm import tqdm

 import torch
-from lhotse import load_manifest_lazy, LilcomChunkyWriter
+from lhotse import LilcomChunkyWriter, load_manifest_lazy
 from lhotse.features.kaldifeat import (
     KaldifeatFbank,
     KaldifeatFbankConfig,
-    KaldifeatMelOptions,
     KaldifeatFrameOptions,
+    KaldifeatMelOptions,
 )
 from lhotse.manipulation import combine

 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
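Aside: `from tqdm import tqdm` was presumably deleted because nothing in the script uses it, which flake8 reports as F401. A sketch, assuming the import really was dead at this point:

```python
from tqdm import tqdm  # F401: 'tqdm.tqdm' imported but unused

def main() -> None:
    # Nothing below references tqdm, so the import above is dead weight
    # and a style pass deletes it outright.
    print("processing...")
```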
@@ -97,27 +95,32 @@ def compute_fbank_spgispeech(args):
     )

     if args.train:
-        logging.info(f"Processing train")
-        cut_set = load_manifest_lazy(src_dir / f"cuts_train_raw.jsonl.gz")
+        logging.info("Processing train")
+        cut_set = load_manifest_lazy(src_dir / "cuts_train_raw.jsonl.gz")
         chunk_size = len(cut_set) // args.num_splits
         cut_sets = cut_set.split_lazy(
             output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
+        stop = (
+            min(args.stop, args.num_splits)
+            if args.stop > 0
+            else args.num_splits
+        )
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
             cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz"
             logging.info(f"Processing train split {i}")
             cs = cut_sets[i].compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_train_{idx}",
                 manifest_path=src_dir / f"cuts_train_{idx}.jsonl.gz",
                 batch_duration=500,
                 num_workers=4,
                 storage_type=LilcomChunkyWriter,
             )
             cs.to_file(cuts_train_idx_path)

     if args.test:
         for partition in ["dev", "val"]:
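Aside: the zero-padded split index keeps the per-split manifest names lexically sortable — with `num_splits=20` the files are numbered 01..20 rather than 1..20. A quick check of the same expression:

```python
num_splits = 20
num_digits = len(str(num_splits))  # -> 2

for i in [0, 8, 19]:
    idx = f"{i + 1}".zfill(num_digits)
    print(f"cuts_train_{idx}.jsonl.gz")
# cuts_train_01.jsonl.gz
# cuts_train_09.jsonl.gz
# cuts_train_20.jsonl.gz
```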
@@ -125,7 +128,9 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
+            cut_set = load_manifest_lazy(
+                src_dir / f"cuts_{partition}_raw.jsonl.gz"
+            )
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -137,8 +142,9 @@ def compute_fbank_spgispeech(args):


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)

     args = get_args()
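Aside: wrapping the assignment in parentheses is purely cosmetic (it brings the line under 80 columns); the format string itself is unchanged. For reference, it yields records shaped like the comment below (timestamp and filename are illustrative):

```python
import logging

formatter = (
    "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
logging.info("hello")
# e.g.: 2022-05-27 12:00:00,000 INFO [example.py:9] hello
```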
@@ -24,7 +24,6 @@ from pathlib import Path

 import torch
 from lhotse import CutSet
-
 from lhotse.recipes.utils import read_manifests_if_cached

 # Torch's multithreaded behavior needs to be disabled or
@@ -56,7 +55,9 @@ def split_spgispeech_train():

     # Add speed perturbation
     train_cuts = (
-        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
+        train_cuts
+        + train_cuts.perturb_speed(0.9)
+        + train_cuts.perturb_speed(1.1)
     )

     # Write the manifests to disk.

@@ -72,8 +73,9 @@ def split_spgispeech_train():


 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)

     split_spgispeech_train()
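Aside: concatenating the original cuts with 0.9x and 1.1x speed-perturbed copies is the standard 3-fold augmentation used across icefall recipes; the reformat only splits the expression across lines. The same pattern as a standalone helper (a sketch; `cuts` is any lhotse CutSet):

```python
from lhotse import CutSet

def threefold_speed_perturb(cuts: CutSet) -> CutSet:
    """Return original + 0.9x + 1.1x speed-perturbed copies (3x the data)."""
    return cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
```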
@@ -38,7 +38,9 @@ def compute_fbank_yesno():
         "test",
     )
     manifests = read_manifests_if_cached(
-        dataset_parts=dataset_parts, output_dir=src_dir
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix="yesno",
     )
     assert manifests is not None
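Aside: this is the compute_fbank_yesno.py change named in the commit message. Besides the reformat, it passes `prefix="yesno"` so the cached manifests are looked up under their prefixed filenames. Roughly, the lookup resolves paths like the sketch below (a simplification, not lhotse's actual implementation):

```python
from pathlib import Path

def manifest_path(output_dir: Path, prefix: str, kind: str, part: str) -> Path:
    # e.g. data/manifests/yesno_recordings_train.jsonl.gz
    return output_dir / f"{prefix}_{kind}_{part}.jsonl.gz"

print(manifest_path(Path("data/manifests"), "yesno", "recordings", "train"))
```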
@@ -18,8 +18,9 @@

 import random
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import Optional, Tuple

 import torch
 from torch import Tensor, nn
@@ -90,8 +91,6 @@ def get_tensor_stats(
     return x, count


-
-
 @dataclass
 class TensorAndCount:
     tensor: Tensor
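Aside: `TensorAndCount` is a plain `@dataclass`, so the decorator generates `__init__`, `__repr__`, and `__eq__` from the field declarations. For comparison, usage looks like this (only `tensor` is visible in the hunk; the `count` field and its type are assumed):

```python
from dataclasses import dataclass

import torch
from torch import Tensor

@dataclass
class TensorAndCount:
    tensor: Tensor
    count: int  # assumed; not shown in this hunk

# The generated __init__ accepts the fields in declaration order:
tc = TensorAndCount(tensor=torch.zeros(4), count=1)
print(tc)  # TensorAndCount(tensor=tensor([0., 0., 0., 0.]), count=1)
```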
@@ -108,12 +107,12 @@ class TensorDiagnostic(object):
       name:
         The tensor name.
     """

     def __init__(self, opts: TensorDiagnosticOptions, name: str):
         self.name = name
         self.opts = opts

-        self.stats = None # we'll later assign a list to this data member. It's a list of dict.
+        self.stats = None  # we'll later assign a list to this data member. It's a list of dict.

         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "value", "positive", "rms", "value".
@@ -125,7 +124,6 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.

-
     def accumulate(self, x):
         """Accumulate tensors."""
         if isinstance(x, Tuple):
@@ -137,7 +135,7 @@ class TensorDiagnostic(object):
             x = x.unsqueeze(0)
         ndim = x.ndim
         if self.stats is None:
-            self.stats = [ dict() for _ in range(ndim) ]
+            self.stats = [dict() for _ in range(ndim)]

         for dim in range(ndim):
             this_dim_stats = self.stats[dim]
@@ -147,10 +145,10 @@ class TensorDiagnostic(object):
                 stats_types.append("eigs")
             else:
                 stats_types = ["value", "abs"]
             this_dict = self.stats[dim]

             for stats_type in stats_types:
                 stats, count = get_tensor_stats(x, dim, stats_type)
-                if not stats_type in this_dim_stats:
+                if stats_type not in this_dim_stats:
                     this_dim_stats[stats_type] = []  # list of TensorAndCount

                 done = False
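Aside: the rewrite fixes flake8's E713 ("test for membership should be 'not in x'"). `not x in y` negates the whole comparison, which evaluates the same but reads ambiguously. Quick demonstration:

```python
d = {"abs": [], "value": []}

print(not "rms" in d)  # True, but flake8 flags E713
print("rms" not in d)  # True, same result, idiomatic
```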
@@ -166,13 +164,17 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
+                    if (
+                        this_dim_stats[stats_type] != []
+                        and stats_type == "eigs"
+                    ):
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                         # don't accumulat "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
+                        this_dim_stats[stats_type].append(
+                            TensorAndCount(stats, count)
+                        )

     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -191,14 +193,18 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print("Error getting eigenvalues, trying another method.")
+                        print(
+                            "Error getting eigenvalues, trying another method."
+                        )
                         eigs = torch.linalg.eigvals(stats)
                         stats = eigs.abs().sqrt()
                 # sqrt so it reflects data magnitude, like stddev- not variance
                 elif len(stats_list) == 1:
                     stats = stats_list[0].tensor / stats_list[0].count
                 else:
-                    stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0)
+                    stats = torch.cat(
+                        [x.tensor / x.count for x in stats_list], dim=0
+                    )

                 if stats_type == "rms":
                     # we stored the square; after aggregation we need to take sqrt.
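Aside: the try/except guards the eigenvalue computation — `torch.symeig` was deprecated in favor of the `torch.linalg` API, and `torch.linalg.eigvals` also handles non-symmetric input by returning complex eigenvalues whose magnitudes are then taken. A standalone sketch of the same fallback, using `torch.linalg.eigvalsh` in place of the deprecated `symeig`:

```python
import torch

def eig_magnitudes(stats: torch.Tensor) -> torch.Tensor:
    """sqrt of |eigenvalues| of a covariance-like matrix."""
    try:
        # Preferred path for symmetric/Hermitian input: real eigenvalues.
        eigs = torch.linalg.eigvalsh(stats)
    except Exception:
        # Generic fallback: complex eigenvalues; keep their magnitudes.
        eigs = torch.linalg.eigvals(stats)
    return eigs.abs().sqrt()

print(eig_magnitudes(torch.eye(3) * 4.0))  # tensor([2., 2., 2.])
```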
@@ -206,7 +212,9 @@ class TensorDiagnostic(object):

                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel())
+                summarize = (
+                    len(stats_list) > 1
+                ) or self.opts.dim_is_summarized(stats.numel())
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -238,9 +246,14 @@ class TensorDiagnostic(object):
                 # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"

                 sizes = [x.tensor.shape[0] for x in stats_list]
-                size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
-                print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}")
+                size_str = (
+                    f"{sizes[0]}"
+                    if len(sizes) == 1
+                    else f"{min(sizes)}..{max(sizes)}"
+                )
+                print(
+                    f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}"
+                )


 class ModelDiagnostic(object):