diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index c6114ce73..6b72b5a0c 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -46,10 +46,18 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
 
+      - name: Install libsndfile and libsox
+        if: startsWith(matrix.os, 'ubuntu')
+        run: |
+          sudo apt update
+          sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
+          sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all
+
       - name: Install Python dependencies
         run: |
           python3 -m pip install --upgrade pip pytest
           pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+          pip install git+https://github.com/lhotse-speech/lhotse
 
           # icefall requirements
           pip install -r requirements.txt
@@ -88,4 +96,3 @@ jobs:
           # runt tests for conformer ctc
           cd egs/librispeech/ASR/conformer_ctc
           pytest
-
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 43a46a30f..eb679b951 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -38,14 +38,16 @@ python conformer_ctc/train.py --bucketing-sampler True \
                               --concatenate-cuts False \
                               --max-duration 200 \
                               --full-libri True \
-                              --world-size 4
+                              --world-size 4 \
+                              --lang-dir data/lang_bpe_5000
 
 python conformer_ctc/decode.py --nbest-scale 0.5 \
                                --epoch 34 \
                                --avg 20 \
                                --method attention-decoder \
                                --max-duration 20 \
-                               --num-paths 100
+                               --num-paths 100 \
+                               --lang-dir data/lang_bpe_5000
 ```
 
 ### LibriSpeech training results (Tdnn-Lstm)
diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md
index 23b51167b..164c3e53e 100644
--- a/egs/librispeech/ASR/conformer_ctc/README.md
+++ b/egs/librispeech/ASR/conformer_ctc/README.md
@@ -1,3 +1,53 @@
+## Introduction
+
 Please visit
 <https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
 for how to run this recipe.
+
+## How to compute framewise alignment information
+
+### Step 1: Train a model
+
+Please use `conformer_ctc/train.py` to train a model.
+See <https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
+for how to do it.
+
+### Step 2: Compute framewise alignment
+
+Run
+
+```
+# Choose a checkpoint and determine the number of checkpoints to average
+epoch=30
+avg=15
+./conformer_ctc/ali.py \
+  --epoch $epoch \
+  --avg $avg \
+  --max-duration 500 \
+  --bucketing-sampler 0 \
+  --full-libri 1 \
+  --exp-dir conformer_ctc/exp \
+  --lang-dir data/lang_bpe_5000 \
+  --ali-dir data/ali_5000
+```
+and you will get four files inside the folder `data/ali_5000`:
+
+```
+$ ls -lh data/ali_5000
+total 546M
+-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt
+-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt
+-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt
+-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt
+```
+
+**Note**: It can take more than 3 hours to compute the alignment
+for the training dataset, which contains 960 * 3 = 2880 hours of data.
+
+**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those
+in `conformer_ctc/train.py`.
+
+**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`.
+Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`.
+
+**TODO:** Add doc about how to use the extracted alignment in the other pull-request.
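The alignment files written in Step 2 are plain PyTorch checkpoints. Below is a minimal sketch of what they contain; it is not part of the patch, and it assumes `ali.py` was run with `--ali-dir data/ali_5000` as in the README above.

```python
import torch

# save_alignments() (added to icefall/utils.py in this patch) stores a dict
# with two keys: "subsampling_factor" and "alignments".
ali = torch.load("data/ali_5000/valid.pt")

print(ali["subsampling_factor"])  # e.g. 4 for the conformer recipe

# "alignments" maps a cut/utterance ID to its framewise token IDs
# (one token ID per acoustic frame after subsampling).
cut_id, token_ids = next(iter(ali["alignments"].items()))
print(cut_id, len(token_ids))
```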
diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py
new file mode 100755
index 000000000..3d817a8f6
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/ali.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import one_best_decoding
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    get_alignments,
+    get_env_info,
+    save_alignments,
+    setup_logger,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=34,
+        help="It specifies the checkpoint to use for decoding. "
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=20,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_5000",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--ali-dir",
+        type=str,
+        default="data/ali_500",
+        help="The dir to save the computed alignments",
+    )
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "lm_dir": Path("data/lm"),
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "subsampling_factor": 4,
+            "num_decoder_layers": 6,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "output_beam": 10,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def compute_alignments(
+    model: torch.nn.Module,
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+) -> List[Tuple[str, List[int]]]:
+    """Compute the framewise alignments of a dataset.
+
+    Args:
+      model:
+        The neural network model.
+      dl:
+        Dataloader containing the dataset.
+      params:
+        Parameters for computing alignments.
+      graph_compiler:
+        It converts token IDs to decoding graphs.
+    Returns:
+      Return a list of tuples. Each tuple contains two entries:
+        - Utterance ID
+        - Framewise alignments (token IDs) after subsampling
+    """
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    num_cuts = 0
+
+    device = graph_compiler.device
+    ans = []
+    for batch_idx, batch in enumerate(dl):
+        feature = batch["inputs"]
+
+        # at entry, feature is [N, T, C]
+        assert feature.ndim == 3
+        feature = feature.to(device)
+
+        supervisions = batch["supervisions"]
+
+        cut_ids = []
+        for cut in supervisions["cut"]:
+            assert len(cut.supervisions) == 1
+            cut_ids.append(cut.id)
+
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is [N, T, C]
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+        # we need also to sort cut_ids as encode_supervisions()
+        # reorders "texts".
+        # In general, new2old is an identity map since lhotse sorts the
+        # returned cuts by duration in descending order
+        new2old = supervision_segments[:, 0].tolist()
+
+        cut_ids = [cut_ids[i] for i in new2old]
+
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        lattice = k2.intersect_dense(
+            decoding_graph,
+            dense_fsa_vec,
+            params.output_beam,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice,
+            use_double_scores=params.use_double_scores,
+        )
+
+        ali_ids = get_alignments(best_path)
+        assert len(ali_ids) == len(cut_ids)
+        ans += list(zip(cut_ids, ali_ids))
+
+        num_cuts += len(ali_ids)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
+
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    assert args.return_cuts is True
+    assert args.concatenate_cuts is False
+    if args.full_libri is False:
+        print("Changing --full-libri to True")
+        args.full_libri = True
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log/ali")
+
+    logging.info("Computing alignment - started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.load_state_dict(average_checkpoints(filenames))
+
+    model.to(device)
+    model.eval()
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_dl = librispeech.train_dataloaders()
+    valid_dl = librispeech.valid_dataloaders()
+    test_dl = librispeech.test_dataloaders()  # a list
+
+    ali_dir = Path(params.ali_dir)
+    ali_dir.mkdir(exist_ok=True)
+
+    enabled_datasets = {
+        "test_clean": test_dl[0],
"test_other": test_dl[1], + "train-960": train_dl, + "valid": valid_dl, + } + # For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to + # compute the alignments if you use --max-duration=500 + # + # There are 960 * 3 = 2880 hours data and it takes only + # 3 hours 40 minutes to get the alignment. + # The RTF is roughly: 3.67 / 2880 = 0.0012743 + # + # At the end, you would see + # 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa + # 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa + for name, dl in enabled_datasets.items(): + logging.info(f"Processing {name}") + if name == "train-960": + logging.info( + f"It will take about 3 hours 40 minutes for {name}, " + "which contains 960 * 3 = 2880 hours of data" + ) + alignments = compute_alignments( + model=model, + dl=dl, + params=params, + graph_compiler=graph_compiler, + ) + num_utt = len(alignments) + alignments = dict(alignments) + assert num_utt == len(alignments) + filename = ali_dir / f"{name}.pt" + save_alignments( + alignments=alignments, + subsampling_factor=params.subsampling_factor, + filename=filename, + ) + logging.info( + f"For dataset {name}, its alignments are saved to {filename}" + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 5a83dd39c..bddb832b0 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -43,6 +43,7 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -142,7 +143,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe", + default="data/lang_bpe_5000", help="The lang dir", ) @@ -167,6 +168,7 @@ def get_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "env_info": get_env_info(), } ) return params diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index 8241c84c1..79e026dac 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -65,7 +65,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe", + default="data/lang_bpe_5000", help="""It contains language related input files such as "lexicon.txt" """, ) diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 99bd9c017..beed6f73b 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -36,7 +36,7 @@ from icefall.decode import ( rescore_with_attention_decoder, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -256,7 +256,7 @@ def main(): params.num_decoder_layers = 0 params.update(vars(args)) - + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index d1cdfa8bb..ae088620f 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -24,16 +24,14 @@ from pathlib 
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index d1cdfa8bb..ae088620f 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -24,16 +24,14 @@
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple
-
 import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from torch import Tensor
-
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -48,6 +46,7 @@ from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
+    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -109,7 +108,7 @@ def get_parser():
     parser.add_argument(
         "--lang-dir",
         type=str,
-        default="data/lang_bpe",
+        default="data/lang_bpe_5000",
         help="""The lang dir
 
         It contains language related input files such as "lexicon.txt"
@@ -185,7 +184,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 10,
+            "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,
             # parameters for conformer
@@ -204,6 +203,7 @@ def get_params() -> AttributeDict:
             "weight_decay": 1e-6,
             "lr_factor": 5.0,
             "warm_step": 80000,
+            "env_info": get_env_info(),
         }
     )
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index b536cb472..c3a09d682 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -41,6 +41,8 @@ dl_dir=$PWD/download
 #    data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
   5000
+  2000
+  1000
   500
 )
 
@@ -191,5 +193,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     ./local/compile_hlg.py --lang-dir $lang_dir
   done
 fi
-
-cd data && ln -sfv lang_bpe_5000 lang_bpe
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 229575db6..950eba438 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -21,10 +21,6 @@ from functools import lru_cache
 from pathlib import Path
 from typing import List, Union
 
-from torch.utils.data import DataLoader
-
-from icefall.dataset.datamodule import DataModule
-from icefall.utils import str2bool
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
@@ -36,6 +32,10 @@ from lhotse.dataset import (
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.dataset.datamodule import DataModule
+from icefall.utils import str2bool
 
 
 class LibriSpeechAsrDataModule(DataModule):
@@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule):
         cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
 
         logging.info("About to create train dataset")
-        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
+        transforms = [
+            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+        ]
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
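The `preserve_id=True` change above is what the README's second caution refers to: without it, cuts that get MUSAN noise mixed in receive new IDs and can no longer be matched against the utterance IDs stored by `ali.py`. Below is a minimal sketch of that behaviour; it is not part of the patch, and the manifest path is illustrative.

```python
from lhotse import load_manifest
from lhotse.dataset import CutMix

# Illustrative path; the data module builds it from self.args.feature_dir.
cuts_musan = load_manifest("data/fbank/cuts_musan.json.gz")

transform = CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)

# When the dataset applies the transform, e.g.
#   mixed_cuts = transform(train_cuts)
# each mixed cut keeps the ID of the original cut, so the IDs written by
# conformer_ctc/ali.py still line up with the training cuts. Without
# preserve_id=True, mixed cuts would get fresh IDs instead.
```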
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 54c2f7a6b..d9d019743 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -39,6 +39,7 @@ from icefall.decode import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
@@ -134,6 +135,7 @@ def get_params() -> AttributeDict:
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
+            "env_info": get_env_info(),
         }
     )
     return params
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 2baeb6bba..e0d6a7a60 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_whole_lattice,
 )
-from icefall.utils import AttributeDict, get_texts
+from icefall.utils import AttributeDict, get_env_info, get_texts
 
 
 def get_parser():
@@ -159,6 +159,7 @@ def main():
     params = get_params()
     params.update(vars(args))
+    params["env_info"] = get_env_info()
 
     logging.info(f"{params}")
 
     device = torch.device("cpu")
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 4a8574019..7904b0e61 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -28,11 +28,10 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
-from torch import Tensor
-
 from asr_datamodule import LibriSpeechAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import TdnnLstm
+from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import StepLR
@@ -47,6 +46,7 @@ from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
+    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -171,6 +171,7 @@ def get_params() -> AttributeDict:
             "beam_size": 10,
             "reduction": "sum",
             "use_double_scores": True,
+            "env_info": get_env_info(),
         }
     )
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 57122235a..9df019bf5 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -17,6 +17,7 @@ from icefall.decode import get_lattice, one_best_decoding
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
@@ -256,6 +257,7 @@ def main():
     params = get_params()
     params.update(vars(args))
+    params["env_info"] = get_env_info()
 
     setup_logger(f"{params.exp_dir}/log/log-decode")
     logging.info("Decoding started")
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 14220be19..75758b984 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -29,7 +29,7 @@ from model import Tdnn
 from torch.nn.utils.rnn import pad_sequence
 
 from icefall.decode import get_lattice, one_best_decoding
-from icefall.utils import AttributeDict, get_texts
+from icefall.utils import AttributeDict, get_env_info, get_texts
 
 
 def get_parser():
@@ -116,6 +116,7 @@ def main():
     params = get_params()
     params.update(vars(args))
+    params["env_info"] = get_env_info()
 
     logging.info(f"{params}")
 
     device = torch.device("cpu")
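All of the decode/pretrained/train scripts above gain the same one-line change: they attach the output of `get_env_info()` (added to `icefall/utils.py` later in this patch) to `params` so the information lands in the logs. A minimal sketch of what gets recorded, not part of the patch:

```python
from icefall.utils import AttributeDict, get_env_info

# env_info is a plain dict with k2/lhotse versions, CUDA availability,
# icefall git metadata (branch, sha1, date) and installed package paths.
params = AttributeDict({"env_info": get_env_info()})
print(params)
```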
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index d414962ca..e24061fa1 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -11,10 +11,10 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
-from torch import Tensor
 from asr_datamodule import YesNoAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import Tdnn
+from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
@@ -24,7 +24,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    get_env_info,
+    setup_logger,
+    str2bool,
+)
 
 
 def get_parser():
@@ -465,6 +471,7 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
+    params["env_info"] = get_env_info()
 
     fix_random_seed(42)
     if world_size > 1:
diff --git a/icefall/utils.py b/icefall/utils.py
index 66aa5c601..287b917c5 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -17,18 +17,21 @@
 
 import argparse
-import logging
 import collections
+import logging
 import os
 import subprocess
+import sys
 from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, TextIO, Tuple, Union
+from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
 
 import k2
+import k2.version
 import kaldialign
+import lhotse
 import torch
 import torch.distributed as dist
 from torch.utils.tensorboard import SummaryWriter
@@ -135,17 +138,82 @@ def setup_logger(
     logging.getLogger("").addHandler(console)
 
 
-def get_env_info():
-    """
-    TODO:
-    """
+def get_git_sha1():
+    git_commit = (
+        subprocess.run(
+            ["git", "rev-parse", "--short", "HEAD"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    dirty_commit = (
+        len(
+            subprocess.run(
+                ["git", "diff", "--shortstat"],
+                check=True,
+                stdout=subprocess.PIPE,
+            )
+            .stdout.decode()
+            .rstrip("\n")
+            .strip()
+        )
+        > 0
+    )
+    git_commit = (
+        git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+    )
+    return git_commit
+
+
+def get_git_date():
+    git_date = (
+        subprocess.run(
+            ["git", "log", "-1", "--format=%ad", "--date=local"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    return git_date
+
+
+def get_git_branch_name():
+    git_branch = (
+        subprocess.run(
+            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    return git_branch
+
+
+def get_env_info() -> Dict[str, Any]:
+    """Get the environment information."""
     return {
-        "k2-git-sha1": None,
-        "k2-version": None,
-        "lhotse-version": None,
-        "torch-version": None,
-        "icefall-sha1": None,
-        "icefall-version": None,
+        "k2-version": k2.version.__version__,
+        "k2-build-type": k2.version.__build_type__,
+        "k2-with-cuda": k2.with_cuda,
+        "k2-git-sha1": k2.version.__git_sha1__,
+        "k2-git-date": k2.version.__git_date__,
+        "lhotse-version": lhotse.__version__,
+        "torch-cuda-available": torch.cuda.is_available(),
+        "torch-cuda-version": torch.version.cuda,
+        "python-version": sys.version[:3],
+        "icefall-git-branch": get_git_branch_name(),
+        "icefall-git-sha1": get_git_sha1(),
+        "icefall-git-date": get_git_date(),
+        "icefall-path": str(Path(__file__).resolve().parent.parent),
+        "k2-path": str(Path(k2.__file__).resolve()),
"lhotse-path": str(Path(lhotse.__file__).resolve()), } @@ -238,6 +306,73 @@ def get_texts( return aux_labels.tolist() +def get_alignments(best_paths: k2.Fsa) -> List[List[int]]: + """Extract the token IDs (from best_paths.labels) from the best-path FSAs. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). + Returns: + Returns a list of lists of int, containing the token sequences we + decoded. For `ans[i]`, its length equals to the number of frames + after subsampling of the i-th utterance in the batch. + """ + # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here + label_shape = best_paths.arcs.shape().remove_axis(1) + # label_shape has axes [fsa][arc] + labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous()) + labels = labels.remove_values_eq(-1) + return labels.tolist() + + +def save_alignments( + alignments: Dict[str, List[int]], + subsampling_factor: int, + filename: str, +) -> None: + """Save alignments to a file. + + Args: + alignments: + A dict containing alignments. Keys of the dict are utterances and + values are the corresponding framewise alignments after subsampling. + subsampling_factor: + The subsampling factor of the model. + filename: + Path to save the alignments. + Returns: + Return None. + """ + ali_dict = { + "subsampling_factor": subsampling_factor, + "alignments": alignments, + } + torch.save(ali_dict, filename) + + +def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: + """Load alignments from a file. + + Args: + filename: + Path to the file containing alignment information. + The file should be saved by :func:`save_alignments`. + Returns: + Return a tuple containing: + - subsampling_factor: The subsampling_factor used to compute + the alignments. + - alignments: A dict containing utterances and their corresponding + framewise alignment, after subsampling. + """ + ali_dict = torch.load(filename) + subsampling_factor = ali_dict["subsampling_factor"] + alignments = ali_dict["alignments"] + return subsampling_factor, alignments + + def store_transcripts( filename: Pathlike, texts: Iterable[Tuple[str, str]] ) -> None: diff --git a/test/test_utils.py b/test/test_utils.py index 7ac52b289..b8c742c5a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,7 +20,12 @@ import k2 import pytest import torch -from icefall.utils import AttributeDict, encode_supervisions, get_texts +from icefall.utils import ( + AttributeDict, + encode_supervisions, + get_env_info, + get_texts, +) @pytest.fixture @@ -108,6 +113,7 @@ def test_attribute_dict(): assert s["b"] == 20 s.c = 100 assert s["c"] == 100 + assert hasattr(s, "a") assert hasattr(s, "b") assert getattr(s, "a") == 10 @@ -119,3 +125,8 @@ def test_attribute_dict(): del s.a except AttributeError as ex: print(f"Caught exception: {ex}") + + +def test_get_env_info(): + s = get_env_info() + print(s)