Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 17:42:21 +00:00
[Ready to merge] Do some coding style checks for the latest files (#379)
* style check
* do changes for .flake8
* a change for compute_fbank_yesno.py
parent 2900ed8f8f
commit ec5a112831
.flake8 (14 lines changed)
@@ -4,15 +4,11 @@ statistics=true
 max-line-length = 80
 per-file-ignores =
     # line too long
-    egs/librispeech/ASR/*/conformer.py: E501,
-    egs/aishell/ASR/*/conformer.py: E501,
-    egs/tedlium3/ASR/*/conformer.py: E501,
-    egs/gigaspeech/ASR/*/conformer.py: E501,
-    egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
-    egs/gigaspeech/ASR/pruned_transducer_stateless2/*.py: E501,
-    egs/librispeech/ASR/pruned_transducer_stateless4/*.py: E501,
-    egs/librispeech/ASR/*/optim.py: E501,
-    egs/librispeech/ASR/*/scaling.py: E501,
+    icefall/diagnostics.py: E501
+    egs/*/ASR/*/conformer.py: E501,
+    egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
+    egs/*/ASR/*/optim.py: E501,
+    egs/*/ASR/*/scaling.py: E501,
 
     # invalid escape sequence (cause by tex formular), W605
     icefall/utils.py: E501, W605
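The consolidated entries rely on the per-file-ignores patterns matching whole relative paths, so a single `egs/*/ASR/*/conformer.py` line covers the librispeech, aishell, tedlium3, and gigaspeech copies at once. A minimal sketch of that matching behaviour, using Python's fnmatch as a rough stand-in for flake8's pattern matching (the paths below are illustrative, not taken from the repository):

from fnmatch import fnmatch

# Hypothetical relative paths of the kind the old .flake8 listed one by one.
paths = [
    "egs/librispeech/ASR/conformer_ctc/conformer.py",
    "egs/aishell/ASR/conformer_ctc/conformer.py",
    "egs/librispeech/ASR/pruned_transducer_stateless2/train.py",
    "egs/librispeech/ASR/pruned_transducer_stateless4/optim.py",
]

# Two of the wildcard entries from the new .flake8.
patterns = [
    "egs/*/ASR/*/conformer.py",
    "egs/*/ASR/pruned_transducer_stateless*/*.py",
]

for p in paths:
    matched = any(fnmatch(p, pat) for pat in patterns)
    print(f"{p}: {'E501 ignored' if matched else 'E501 enforced'}")

The hunks that follow apply the same style pass to the Python scripts whose functions appear in the hunk contexts (compute_fbank_musan, compute_fbank_spgispeech, split_spgispeech_train, compute_fbank_yesno) and to icefall/diagnostics.py, which the new entry above exempts from E501.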
@@ -27,17 +27,15 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import LilcomChunkyWriter, CutSet, combine
+from lhotse import CutSet, LilcomChunkyWriter, combine
 from lhotse.features.kaldifeat import (
     KaldifeatFbank,
     KaldifeatFbankConfig,
-    KaldifeatMelOptions,
     KaldifeatFrameOptions,
+    KaldifeatMelOptions,
 )
 from lhotse.recipes.utils import read_manifests_if_cached
 
-from icefall.utils import get_executor
-
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
 # Do this outside of main() in case it needs to take effect
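The reordered import lines above (and the matching ones further down) follow the usual isort/black convention for this repository: names inside a `from` import end up with CamelCase classes such as `CutSet` and `LilcomChunkyWriter` ahead of lowercase names such as `combine`. One simple rule that reproduces the same order is plain code-point sorting; whether the project applied isort or sorted by hand, the result here is identical. A tiny sketch with an illustrative list of names:

# Code-point order places uppercase (CamelCase) names before lowercase ones.
names = ["load_manifest_lazy", "LilcomChunkyWriter", "combine", "CutSet"]
print(sorted(names))
# ['CutSet', 'LilcomChunkyWriter', 'combine', 'load_manifest_lazy']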
@@ -82,23 +80,28 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(part["recordings"] for part in manifests.values())
+            recordings=combine(
+                part["recordings"] for part in manifests.values()
+            )
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
         .compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=output_dir / f"feats_musan",
-            manifest_path=src_dir / f"cuts_musan.jsonl.gz",
+            storage_path=output_dir / "feats_musan",
             batch_duration=500,
             num_workers=4,
             storage_type=LilcomChunkyWriter,
         )
     )
 
+    logging.info(f"Saving to {musan_cuts_path}")
+    musan_cuts.to_file(musan_cuts_path)
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
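Several of the changes above also drop the `f` prefix from strings that contain no placeholders, such as `f"feats_musan"`; flake8 (via pyflakes) flags these as F541, and a plain string literal is equivalent. A small self-contained sketch, with an illustrative directory name:

from pathlib import Path

output_dir = Path("data/fbank")  # illustrative directory, not from the recipe

# f-string with no placeholders: flagged by flake8 as F541.
p1 = output_dir / f"feats_musan"  # noqa: F541
# Equivalent plain string, as used after the change.
p2 = output_dir / "feats_musan"

assert p1 == p2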
@@ -25,17 +25,15 @@ The generated fbank features are saved in data/fbank.
 import argparse
 import logging
 from pathlib import Path
-from tqdm import tqdm
 
 import torch
-from lhotse import load_manifest_lazy, LilcomChunkyWriter
+from lhotse import LilcomChunkyWriter, load_manifest_lazy
 from lhotse.features.kaldifeat import (
     KaldifeatFbank,
     KaldifeatFbankConfig,
-    KaldifeatMelOptions,
     KaldifeatFrameOptions,
+    KaldifeatMelOptions,
 )
-from lhotse.manipulation import combine
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -97,27 +95,32 @@ def compute_fbank_spgispeech(args):
     )
 
     if args.train:
-        logging.info(f"Processing train")
-        cut_set = load_manifest_lazy(src_dir / f"cuts_train_raw.jsonl.gz")
+        logging.info("Processing train")
+        cut_set = load_manifest_lazy(src_dir / "cuts_train_raw.jsonl.gz")
         chunk_size = len(cut_set) // args.num_splits
         cut_sets = cut_set.split_lazy(
             output_dir=src_dir / f"cuts_train_raw_split{args.num_splits}",
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
+        stop = (
+            min(args.stop, args.num_splits)
+            if args.stop > 0
+            else args.num_splits
+        )
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
+            cuts_train_idx_path = src_dir / f"cuts_train_{idx}.jsonl.gz"
             logging.info(f"Processing train split {i}")
             cs = cut_sets[i].compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_train_{idx}",
-                manifest_path=src_dir / f"cuts_train_{idx}.jsonl.gz",
                 batch_duration=500,
                 num_workers=4,
                 storage_type=LilcomChunkyWriter,
             )
+            cs.to_file(cuts_train_idx_path)
 
     if args.test:
         for partition in ["dev", "val"]:
@@ -125,7 +128,9 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
+            cut_set = load_manifest_lazy(
+                src_dir / f"cuts_{partition}_raw.jsonl.gz"
+            )
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -137,8 +142,9 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
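The `formatter = (...)` rewrite that recurs in these scripts is the standard way to keep a long string assignment under the 80-character limit set in .flake8: wrapping the right-hand side in parentheses lets the literal sit on its own line without changing the value. A minimal sketch:

# Both assignments produce exactly the same string; the second simply
# keeps each physical line within an 80-column limit.
formatter_long = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter_wrapped = (
    "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
assert formatter_long == formatter_wrapped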
@@ -24,7 +24,6 @@ from pathlib import Path
 
 import torch
 from lhotse import CutSet
-
 from lhotse.recipes.utils import read_manifests_if_cached
 
 # Torch's multithreaded behavior needs to be disabled or
@@ -56,7 +55,9 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
+        train_cuts
+        + train_cuts.perturb_speed(0.9)
+        + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -72,8 +73,9 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
@@ -38,7 +38,9 @@ def compute_fbank_yesno():
         "test",
     )
     manifests = read_manifests_if_cached(
-        dataset_parts=dataset_parts, output_dir=src_dir
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix="yesno",
     )
     assert manifests is not None
 
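The added `prefix="yesno"` makes the manifest lookup explicit: with a prefix, lhotse's read_manifests_if_cached looks for manifest files whose names start with `yesno_`, matching what the corresponding prepare step wrote out. A minimal sketch of the call, assuming the manifests already exist on disk; the directory and dataset parts below are placeholders inferred from the hunk, not verified against the recipe:

from pathlib import Path

from lhotse.recipes.utils import read_manifests_if_cached

src_dir = Path("data/manifests")  # placeholder path
dataset_parts = ("train", "test")  # assumed parts for the yesno recipe

# With a prefix, lhotse searches for manifests named with a "yesno_" stem.
manifests = read_manifests_if_cached(
    dataset_parts=dataset_parts,
    output_dir=src_dir,
    prefix="yesno",
)
assert manifests is not None
for part, m in manifests.items():
    print(part, sorted(m.keys()))  # typically "recordings" and "supervisions"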
@@ -18,8 +18,9 @@
 
 
 import random
-from typing import List, Optional, Tuple
 from dataclasses import dataclass
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor, nn
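Several deletions in the import hunks above (`List` from typing, `tqdm`, `combine`, `get_executor`) look like plain unused-import removals, which flake8 reports as F401. A tiny illustration of the rule, with made-up names:

from typing import List, Optional  # List is never used below -> flake8 F401

value: Optional[int] = None
print(value)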
@@ -90,8 +91,6 @@ def get_tensor_stats(
     return x, count
 
 
-
-
 @dataclass
 class TensorAndCount:
     tensor: Tensor
@@ -108,12 +107,12 @@ class TensorDiagnostic(object):
       name:
         The tensor name.
     """
 
     def __init__(self, opts: TensorDiagnosticOptions, name: str):
         self.name = name
         self.opts = opts
-        self.stats = None # we'll later assign a list to this data member. It's a list of dict.
+        self.stats = None  # we'll later assign a list to this data member. It's a list of dict.
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "value", "positive", "rms", "value".
@@ -125,7 +124,6 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
-
     def accumulate(self, x):
         """Accumulate tensors."""
         if isinstance(x, Tuple):
@@ -137,7 +135,7 @@ class TensorDiagnostic(object):
             x = x.unsqueeze(0)
         ndim = x.ndim
         if self.stats is None:
-            self.stats = [ dict() for _ in range(ndim) ]
+            self.stats = [dict() for _ in range(ndim)]
 
         for dim in range(ndim):
             this_dim_stats = self.stats[dim]
@@ -147,10 +145,10 @@ class TensorDiagnostic(object):
                 stats_types.append("eigs")
             else:
                 stats_types = ["value", "abs"]
-            this_dict = self.stats[dim]
             for stats_type in stats_types:
                 stats, count = get_tensor_stats(x, dim, stats_type)
-                if not stats_type in this_dim_stats:
+                if stats_type not in this_dim_stats:
                     this_dim_stats[stats_type] = []  # list of TensorAndCount
 
                 done = False
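The `not stats_type in ...` to `stats_type not in ...` change fixes flake8's E713 ("test for membership should be 'not in x'"). The two spellings are equivalent; `not in` is simply the idiomatic form. A tiny sketch:

stats = {"abs": [], "value": []}

# E713 style: parsed as `not ("eigs" in stats)`, but reads ambiguously.
print(not "eigs" in stats)   # True
# Preferred, equivalent form.
print("eigs" not in stats)   # True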
@@ -166,13 +164,17 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
+                    if (
+                        this_dim_stats[stats_type] != []
+                        and stats_type == "eigs"
+                    ):
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                         # don't accumulat "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
+                        this_dim_stats[stats_type].append(
+                            TensorAndCount(stats, count)
+                        )
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -191,14 +193,18 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print("Error getting eigenvalues, trying another method.")
+                        print(
+                            "Error getting eigenvalues, trying another method."
+                        )
                         eigs = torch.linalg.eigvals(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev- not variance
                 elif len(stats_list) == 1:
                     stats = stats_list[0].tensor / stats_list[0].count
                 else:
-                    stats = torch.cat([x.tensor / x.count for x in stats_list], dim=0)
+                    stats = torch.cat(
+                        [x.tensor / x.count for x in stats_list], dim=0
+                    )
 
                 if stats_type == "rms":
                     # we stored the square; after aggregation we need to take sqrt.
@@ -206,7 +212,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(stats.numel())
+                summarize = (
+                    len(stats_list) > 1
+                ) or self.opts.dim_is_summarized(stats.numel())
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -238,9 +246,14 @@ class TensorDiagnostic(object):
                 # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
-                size_str = f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
-                print(f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}")
+                size_str = (
+                    f"{sizes[0]}"
+                    if len(sizes) == 1
+                    else f"{min(sizes)}..{max(sizes)}"
+                )
+                print(
+                    f"module={self.name}, dim={dim}, size={size_str}, {stats_type} {ans}"
+                )
 
 
 class ModelDiagnostic(object):
Loading…
x
Reference in New Issue
Block a user