mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 06:04:18 +00:00
Merge branch 'master' of https://github.com/k2-fsa/icefall into spgi
This commit is contained in:
commit
397216eb78
4
.github/workflows/style_check.yml
vendored
4
.github/workflows/style_check.yml
vendored
@ -45,7 +45,9 @@ jobs:
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2
|
||||
python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
|
||||
# See https://github.com/psf/black/issues/2964
|
||||
# The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
|
||||
|
||||
- name: Run flake8
|
||||
shell: bash
|
||||
|
@ -27,9 +27,21 @@ Installation
|
||||
``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
|
||||
`lhotse <https://github.com/lhotse-speech/lhotse>`_.
|
||||
|
||||
We recommend you to install ``k2`` first, as ``k2`` is bound to
|
||||
a specific version of PyTorch after compilation. Install ``k2`` also
|
||||
installs its dependency PyTorch, which can be reused by ``lhotse``.
|
||||
We recommend you to use the following steps to install the dependencies.
|
||||
|
||||
- (0) Install PyTorch and torchaudio
|
||||
- (1) Install k2
|
||||
- (2) Install lhotse
|
||||
|
||||
.. caution::
|
||||
|
||||
Installation order matters.
|
||||
|
||||
(0) Install PyTorch and torchaudio
|
||||
----------------------------------
|
||||
|
||||
Please refer `<https://pytorch.org/>`_ to install PyTorch
|
||||
and torchaudio.
|
||||
|
||||
|
||||
(1) Install k2
|
||||
@ -54,14 +66,15 @@ to install ``k2``.
|
||||
Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
|
||||
to install ``lhotse``.
|
||||
|
||||
.. HINT::
|
||||
|
||||
Install ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_.
|
||||
.. hint::
|
||||
|
||||
.. CAUTION::
|
||||
We strongly recommend you to use::
|
||||
|
||||
pip install git+https://github.com/lhotse-speech/lhotse
|
||||
|
||||
to install the latest version of lhotse.
|
||||
|
||||
If you have installed ``torchaudio``, please consider uninstalling it before
|
||||
installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
|
||||
|
||||
(3) Download icefall
|
||||
--------------------
|
||||
|
@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@ -496,7 +497,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -23,6 +23,7 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
@ -34,11 +35,20 @@ from lhotse.dataset import (
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
@ -301,12 +311,18 @@ class LibriSpeechAsrDataModule:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
@ -393,7 +394,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
@ -397,7 +398,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -109,8 +109,11 @@ class Conformer(Transformer):
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
|
||||
assert x.size(0) == lengths.max().item()
|
||||
mask = make_pad_mask(lengths)
|
||||
|
||||
|
@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
@ -419,7 +420,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -22,6 +22,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
@ -34,11 +35,20 @@ from lhotse.dataset.input_strategies import (
|
||||
OnTheFlyFeatures,
|
||||
PrecomputedFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class AsrDataModule:
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
@ -253,12 +263,19 @@ class AsrDataModule:
|
||||
)
|
||||
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
return train_dl
|
||||
|
||||
|
@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
@ -466,7 +467,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
@ -135,8 +135,13 @@ def get_diagnostics_for_dim(
|
||||
return ""
|
||||
count = sum(counts)
|
||||
stats = stats / count
|
||||
stats, _ = torch.symeig(stats)
|
||||
stats = stats.abs().sqrt()
|
||||
try:
|
||||
eigs, _ = torch.symeig(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
except: # noqa
|
||||
print("Error getting eigenvalues, trying another method.")
|
||||
eigs = torch.linalg.eigvals(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
# sqrt so it reflects data magnitude, like stddev- not variance
|
||||
elif sizes_same:
|
||||
stats = torch.stack(stats).sum(dim=0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user