mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Merge changes from master
This commit is contained in:
commit
eec597fdd5
3
.flake8
3
.flake8
@ -13,4 +13,5 @@ per-file-ignores =
|
||||
exclude =
|
||||
.git,
|
||||
**/data/**,
|
||||
icefall/shared/make_kn_lm.py
|
||||
icefall/shared/make_kn_lm.py,
|
||||
icefall/__init__.py
|
||||
|
4
.github/workflows/style_check.yml
vendored
4
.github/workflows/style_check.yml
vendored
@ -45,7 +45,9 @@ jobs:
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2
|
||||
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
|
||||
# See https://github.com/psf/black/issues/2964
|
||||
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
|
||||
|
||||
- name: Run flake8
|
||||
shell: bash
|
||||
|
@ -27,9 +27,21 @@ Installation
|
||||
``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
|
||||
`lhotse <https://github.com/lhotse-speech/lhotse>`_.
|
||||
|
||||
We recommend you to install ``k2`` first, as ``k2`` is bound to
|
||||
a specific version of PyTorch after compilation. Install ``k2`` also
|
||||
installs its dependency PyTorch, which can be reused by ``lhotse``.
|
||||
We recommend you to use the following steps to install the dependencies.
|
||||
|
||||
- (0) Install PyTorch and torchaudio
|
||||
- (1) Install k2
|
||||
- (2) Install lhotse
|
||||
|
||||
.. caution::
|
||||
|
||||
Installation order matters.
|
||||
|
||||
(0) Install PyTorch and torchaudio
|
||||
----------------------------------
|
||||
|
||||
Please refer `<https://pytorch.org/>`_ to install PyTorch
|
||||
and torchaudio.
|
||||
|
||||
|
||||
(1) Install k2
|
||||
@ -54,14 +66,15 @@ to install ``k2``.
|
||||
Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
|
||||
to install ``lhotse``.
|
||||
|
||||
.. HINT::
|
||||
|
||||
Install ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_.
|
||||
.. hint::
|
||||
|
||||
.. CAUTION::
|
||||
We strongly recommend you to use::
|
||||
|
||||
pip install git+https://github.com/lhotse-speech/lhotse
|
||||
|
||||
to install the latest version of lhotse.
|
||||
|
||||
If you have installed ``torchaudio``, please consider uninstalling it before
|
||||
installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
|
||||
|
||||
(3) Download icefall
|
||||
--------------------
|
||||
|
@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@ -496,7 +497,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -113,6 +113,8 @@ class Conformer(EncoderInterface):
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
assert x.size(0) == lengths.max().item()
|
||||
|
@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@ -512,6 +513,8 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
|
@ -25,6 +25,7 @@ from typing import Any, Dict, Optional
|
||||
import torch
|
||||
from lhotse.utils import fix_random_seed
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
@ -36,11 +37,20 @@ from lhotse.dataset import (
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
@ -303,11 +313,10 @@ class LibriSpeechAsrDataModule:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have previously been
|
||||
# set in the main process.
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
def worker_init_fn(worker_id: int):
|
||||
fix_random_seed(seed + worker_id)
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
|
@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
@ -393,7 +394,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
@ -397,7 +398,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -107,8 +107,11 @@ class Conformer(Transformer):
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
|
||||
assert x.size(0) == lengths.max().item()
|
||||
mask = make_pad_mask(lengths)
|
||||
|
||||
|
@ -34,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import diagnostics # ./diagnostics.py
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
@ -424,7 +424,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -22,6 +22,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
@ -34,11 +35,20 @@ from lhotse.dataset.input_strategies import (
|
||||
OnTheFlyFeatures,
|
||||
PrecomputedFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class AsrDataModule:
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
@ -253,12 +263,19 @@ class AsrDataModule:
|
||||
)
|
||||
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
return train_dl
|
||||
|
||||
|
@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
@ -466,7 +467,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -0,0 +1,24 @@
|
||||
from .utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
add_eos,
|
||||
add_sos,
|
||||
concat,
|
||||
encode_supervisions,
|
||||
get_alignments,
|
||||
get_executor,
|
||||
get_texts,
|
||||
l1_norm,
|
||||
l2_norm,
|
||||
linf_norm,
|
||||
load_alignments,
|
||||
make_pad_mask,
|
||||
measure_gradient_norms,
|
||||
measure_weight_norms,
|
||||
optim_step_and_measure_param_change,
|
||||
save_alignments,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
@ -138,9 +138,9 @@ def get_diagnostics_for_dim(
|
||||
try:
|
||||
eigs, _ = torch.symeig(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
except:
|
||||
print("Error getting eigenvalues, trying another method")
|
||||
eigs, _ = torch.eigs(stats)
|
||||
except: # noqa
|
||||
print("Error getting eigenvalues, trying another method.")
|
||||
eigs = torch.linalg.eigvals(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
# sqrt so it reflects data magnitude, like stddev- not variance
|
||||
elif sizes_same:
|
||||
|
Loading…
x
Reference in New Issue
Block a user