Merge branch 'k2-fsa:master' into wenetspeech-pruned-transducer-stateless-pinyin

This commit is contained in:
Mingshuang Luo 2022-04-08 17:25:03 +08:00 committed by GitHub
commit 5242275b8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 252 additions and 240 deletions

View File

@ -14,4 +14,5 @@ per-file-ignores =
exclude = exclude =
.git, .git,
**/data/**, **/data/**,
icefall/shared/make_kn_lm.py icefall/shared/make_kn_lm.py,
icefall/__init__.py

View File

@ -45,7 +45,9 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
# See https://github.com/psf/black/issues/2964
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
- name: Run flake8 - name: Run flake8
shell: bash shell: bash

View File

@ -27,9 +27,21 @@ Installation
``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and ``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
`lhotse <https://github.com/lhotse-speech/lhotse>`_. `lhotse <https://github.com/lhotse-speech/lhotse>`_.
We recommend you to install ``k2`` first, as ``k2`` is bound to We recommend you to use the following steps to install the dependencies.
a specific version of PyTorch after compilation. Install ``k2`` also
installs its dependency PyTorch, which can be reused by ``lhotse``. - (0) Install PyTorch and torchaudio
- (1) Install k2
- (2) Install lhotse
.. caution::
Installation order matters.
(0) Install PyTorch and torchaudio
----------------------------------
Please refer `<https://pytorch.org/>`_ to install PyTorch
and torchaudio.
(1) Install k2 (1) Install k2
@ -54,14 +66,15 @@ to install ``k2``.
Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_ Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
to install ``lhotse``. to install ``lhotse``.
.. HINT::
Install ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_. .. hint::
.. CAUTION:: We strongly recommend you to use::
pip install git+https://github.com/lhotse-speech/lhotse
to install the latest version of lhotse.
If you have installed ``torchaudio``, please consider uninstalling it before
installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
(3) Download icefall (3) Download icefall
-------------------- --------------------

View File

@ -1,98 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
"""
Implement the LabelSmoothingLoss proposed in the following paper
https://arxiv.org/pdf/1512.00567.pdf
(Rethinking the Inception Architecture for Computer Vision)
"""
def __init__(
self,
ignore_index: int = -1,
label_smoothing: float = 0.1,
reduction: str = "sum",
) -> None:
"""
Args:
ignore_index:
ignored class id
label_smoothing:
smoothing rate (0.0 means the conventional cross entropy loss)
reduction:
It has the same meaning as the reduction in
`torch.nn.CrossEntropyLoss`. It can be one of the following three
values: (1) "none": No reduction will be applied. (2) "mean": the
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.ignore_index of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.ndim == 3
assert target.ndim == 2
assert x.shape[:2] == target.shape
num_classes = x.size(-1)
x = x.reshape(-1, num_classes)
# Now x is of shape (N*T, C)
# We don't want to change target in-place below,
# so we make a copy of it here
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
target[ignored] = 0
true_dist = torch.nn.functional.one_hot(
target, num_classes=num_classes
).to(x)
true_dist = (
true_dist * (1 - self.label_smoothing)
+ self.label_smoothing / num_classes
)
# Set the value of ignored indexes to 0
true_dist[ignored] = 0
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
return loss.sum()
elif self.reduction == "mean":
return loss.sum() / (~ignored).sum()
else:
return loss.sum(dim=-1)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

View File

@ -1,98 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
"""
Implement the LabelSmoothingLoss proposed in the following paper
https://arxiv.org/pdf/1512.00567.pdf
(Rethinking the Inception Architecture for Computer Vision)
"""
def __init__(
self,
ignore_index: int = -1,
label_smoothing: float = 0.1,
reduction: str = "sum",
) -> None:
"""
Args:
ignore_index:
ignored class id
label_smoothing:
smoothing rate (0.0 means the conventional cross entropy loss)
reduction:
It has the same meaning as the reduction in
`torch.nn.CrossEntropyLoss`. It can be one of the following three
values: (1) "none": No reduction will be applied. (2) "mean": the
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.ignore_index of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.ndim == 3
assert target.ndim == 2
assert x.shape[:2] == target.shape
num_classes = x.size(-1)
x = x.reshape(-1, num_classes)
# Now x is of shape (N*T, C)
# We don't want to change target in-place below,
# so we make a copy of it here
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
target[ignored] = 0
true_dist = torch.nn.functional.one_hot(
target, num_classes=num_classes
).to(x)
true_dist = (
true_dist * (1 - self.label_smoothing)
+ self.label_smoothing / num_classes
)
# Set the value of ignored indexes to 0
true_dist[ignored] = 0
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
return loss.sum()
elif self.reduction == "mean":
return loss.sum() / (~ignored).sum()
else:
return loss.sum(dim=-1)

View File

@ -0,0 +1 @@
../conformer_ctc/label_smoothing.py

View File

@ -70,7 +70,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# |-- lexicon.txt # |-- lexicon.txt
# `-- speaker.info # `-- speaker.info
if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then
lhotse download aishell $dl_dir lhotse download aishell $dl_dir
fi fi

View File

@ -76,7 +76,11 @@ class LabelSmoothingLoss(torch.nn.Module):
target = target.clone().reshape(-1) target = target.clone().reshape(-1)
ignored = target == self.ignore_index ignored = target == self.ignore_index
target[ignored] = 0
# See https://github.com/k2-fsa/icefall/issues/240
# and https://github.com/k2-fsa/icefall/issues/297
# for why we don't use target[ignored] = 0 here
target = torch.where(ignored, torch.zeros_like(target), target)
true_dist = torch.nn.functional.one_hot( true_dist = torch.nn.functional.one_hot(
target, num_classes=num_classes target, num_classes=num_classes
@ -86,8 +90,17 @@ class LabelSmoothingLoss(torch.nn.Module):
true_dist * (1 - self.label_smoothing) true_dist * (1 - self.label_smoothing)
+ self.label_smoothing / num_classes + self.label_smoothing / num_classes
) )
# Set the value of ignored indexes to 0 # Set the value of ignored indexes to 0
true_dist[ignored] = 0 #
# See https://github.com/k2-fsa/icefall/issues/240
# and https://github.com/k2-fsa/icefall/issues/297
# for why we don't use true_dist[ignored] = 0 here
true_dist = torch.where(
ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
torch.zeros_like(true_dist),
true_dist,
)
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum": if self.reduction == "sum":

View File

@ -98,27 +98,28 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help="It specifies the checkpoint to use for decoding." help="""It specifies the checkpoint to use for decoding.
"Note: Epoch counts from 0.", Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
) )
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch' and '--iter'",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
) )
parser.add_argument( parser.add_argument(
@ -453,13 +454,19 @@ def main():
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -485,8 +492,20 @@ def main():
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params)
if params.avg_last_n > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))

View File

@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse import argparse
import logging import logging
import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
@ -496,7 +497,11 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()

View File

@ -23,6 +23,7 @@ from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import ( from lhotse.dataset import (
BucketingSampler, BucketingSampler,
@ -34,11 +35,20 @@ from lhotse.dataset import (
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from icefall.utils import str2bool from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriSpeechAsrDataModule: class LibriSpeechAsrDataModule:
""" """
DataModule for k2 ASR experiments. DataModule for k2 ASR experiments.
@ -301,12 +311,18 @@ class LibriSpeechAsrDataModule:
logging.info("Loading sampler state dict") logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict) train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader( train_dl = DataLoader(
train, train,
sampler=train_sampler, sampler=train_sampler,
batch_size=None, batch_size=None,
num_workers=self.args.num_workers, num_workers=self.args.num_workers,
persistent_workers=False, persistent_workers=False,
worker_init_fn=worker_init_fn,
) )
return train_dl return train_dl

View File

@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse import argparse
import logging import logging
import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple from typing import Optional, Tuple
@ -393,7 +394,11 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()

View File

@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2"
import argparse import argparse
import logging import logging
import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple from typing import Optional, Tuple
@ -397,7 +398,11 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()

View File

@ -109,8 +109,11 @@ class Conformer(Transformer):
x, pos_emb = self.encoder_pos(x) x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4! with warnings.catch_warnings():
lengths = ((x_lens - 1) // 2 - 1) // 2 warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths) mask = make_pad_mask(lengths)

View File

@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse import argparse
import logging import logging
import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple from typing import Optional, Tuple
@ -419,7 +420,11 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()

View File

@ -22,6 +22,7 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import ( from lhotse.dataset import (
BucketingSampler, BucketingSampler,
@ -34,11 +35,20 @@ from lhotse.dataset.input_strategies import (
OnTheFlyFeatures, OnTheFlyFeatures,
PrecomputedFeatures, PrecomputedFeatures,
) )
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from icefall.utils import str2bool from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AsrDataModule: class AsrDataModule:
def __init__(self, args: argparse.Namespace): def __init__(self, args: argparse.Namespace):
self.args = args self.args = args
@ -253,12 +263,19 @@ class AsrDataModule:
) )
logging.info("About to create train dataloader") logging.info("About to create train dataloader")
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader( train_dl = DataLoader(
train, train,
sampler=train_sampler, sampler=train_sampler,
batch_size=None, batch_size=None,
num_workers=self.args.num_workers, num_workers=self.args.num_workers,
persistent_workers=False, persistent_workers=False,
worker_init_fn=worker_init_fn,
) )
return train_dl return train_dl

View File

@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse import argparse
import logging import logging
import random import random
import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple from typing import Optional, Tuple
@ -466,7 +467,11 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()

View File

@ -0,0 +1,55 @@
from .checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
remove_checkpoints,
save_checkpoint,
save_checkpoint_with_global_batch_idx,
)
from .decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from .dist import (
cleanup_dist,
setup_dist,
)
from .env import (
get_env_info,
get_git_branch_name,
get_git_date,
get_git_sha1,
)
from .utils import (
AttributeDict,
MetricsTracker,
add_eos,
add_sos,
concat,
encode_supervisions,
get_alignments,
get_executor,
get_texts,
l1_norm,
l2_norm,
linf_norm,
load_alignments,
make_pad_mask,
measure_gradient_norms,
measure_weight_norms,
optim_step_and_measure_param_change,
save_alignments,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)

View File

@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx(
) )
def find_checkpoints(out_dir: Path) -> List[str]: def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
"""Find all available checkpoints in a directory. """Find all available checkpoints in a directory.
The checkpoint filenames have the form: `checkpoint-xxx.pt` The checkpoint filenames have the form: `checkpoint-xxx.pt`
where xxx is a numerical value. where xxx is a numerical value.
Assume you have the following checkpoints in the folder `foo`:
- checkpoint-1.pt
- checkpoint-20.pt
- checkpoint-300.pt
- checkpoint-4000.pt
Case 1 (Return all checkpoints)::
find_checkpoints(out_dir='foo')
Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)
find_checkpoints(out_dir='foo', iteration=20)
Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
checkpoint-20.pt, checkpoint-1.pt)::
find_checkpoints(out_dir='foo', iteration=-20)
Args: Args:
out_dir: out_dir:
The directory where to search for checkpoints. The directory where to search for checkpoints.
iteration:
If it is 0, return all available checkpoints.
If it is positive, return the checkpoints whose iteration number is
greater than or equal to `iteration`.
If it is negative, return the checkpoints whose iteration number is
less than or equal to `-iteration`.
Returns: Returns:
Return a list of checkpoint filenames, sorted in descending Return a list of checkpoint filenames, sorted in descending
order by the numerical value in the filename. order by the numerical value in the filename.
""" """
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
pattern = re.compile(r"checkpoint-([0-9]+).pt") pattern = re.compile(r"checkpoint-([0-9]+).pt")
idx_checkpoints = [ iter_checkpoints = [
(int(pattern.search(c).group(1)), c) for c in checkpoints (int(pattern.search(c).group(1)), c) for c in checkpoints
] ]
# iter_checkpoints is a list of tuples. Each tuple contains
# two elements: (iteration_number, checkpoint-iteration_number.pt)
iter_checkpoints = sorted(
iter_checkpoints, reverse=True, key=lambda x: x[0]
)
if iteration >= 0:
ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
else:
ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]
idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
ans = [ic[1] for ic in idx_checkpoints]
return ans return ans

View File

@ -135,8 +135,13 @@ def get_diagnostics_for_dim(
return "" return ""
count = sum(counts) count = sum(counts)
stats = stats / count stats = stats / count
stats, _ = torch.symeig(stats) try:
stats = stats.abs().sqrt() eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev- not variance # sqrt so it reflects data magnitude, like stddev- not variance
elif sizes_same: elif sizes_same:
stats = torch.stack(stats).sum(dim=0) stats = torch.stack(stats).sum(dim=0)

View File

@ -95,6 +95,7 @@ def get_env_info() -> Dict[str, Any]:
"k2-git-sha1": k2.version.__git_sha1__, "k2-git-sha1": k2.version.__git_sha1__,
"k2-git-date": k2.version.__git_date__, "k2-git-date": k2.version.__git_date__,
"lhotse-version": lhotse.__version__, "lhotse-version": lhotse.__version__,
"torch-version": torch.__version__,
"torch-cuda-available": torch.cuda.is_available(), "torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda, "torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3], "python-version": sys.version[:3],

View File

@ -1,5 +1,6 @@
[tool.isort] [tool.isort]
profile = "black" profile = "black"
skip = ["icefall/__init__.py"]
[tool.black] [tool.black]
line-length = 80 line-length = 80