Extract framewise alignment information using CTC decoding (#39)

* Use new APIs with k2.RaggedTensor

* Fix style issues.

* Update the installation doc, saying it requires at least k2 v1.7

* Extract framewise alignment information using CTC decoding.

* Print environment information.

Print information about k2, lhotse, PyTorch, and icefall (a short usage sketch follows this list).

* Fix CI.

* Fix CI.

* Compute framewise alignment information of the LibriSpeech dataset.

* Update comments for the time to compute alignments of train-960.

* Preserve cut id in mix cut transformer.

* Minor fixes.

* Add doc about how to extract framewise alignments.
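The environment report mentioned above is produced by `get_env_info()` in `icefall/utils.py`, which this commit fleshes out. A minimal usage sketch, assuming icefall is on `PYTHONPATH`:

```python
# Print the environment report added by this commit: k2 and lhotse versions,
# CUDA availability, and icefall git information.
from icefall.utils import get_env_info

info = get_env_info()  # returns a plain dict
for key, value in info.items():
    print(f"{key}: {value}")
```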
Author: Fangjun Kuang, 2021-10-18 14:24:33 +08:00 (committed by GitHub)
parent bd7c2f7645
commit 4890e27b45
18 changed files with 582 additions and 38 deletions

View File

@@ -46,10 +46,18 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

+ - name: Install libsndfile and libsox
+   if: startsWith(matrix.os, 'ubuntu')
+   run: |
+     sudo apt update
+     sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
+     sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all

- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+ pip install git+https://github.com/lhotse-speech/lhotse

# icefall requirements
pip install -r requirements.txt

@@ -88,4 +96,3 @@ jobs:
# run tests for conformer ctc
cd egs/librispeech/ASR/conformer_ctc
pytest

View File

@@ -38,14 +38,16 @@ python conformer_ctc/train.py --bucketing-sampler True \
--concatenate-cuts False \
--max-duration 200 \
--full-libri True \
- --world-size 4
+ --world-size 4 \
+ --lang-dir data/lang_bpe_5000

python conformer_ctc/decode.py --nbest-scale 0.5 \
--epoch 34 \
--avg 20 \
--method attention-decoder \
--max-duration 20 \
- --num-paths 100
+ --num-paths 100 \
+ --lang-dir data/lang_bpe_5000
```

### LibriSpeech training results (Tdnn-Lstm)

View File

@@ -1,3 +1,53 @@
## Introduction

Please visit
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
for how to run this recipe.
## How to compute framewise alignment information
### Step 1: Train a model
Please use `conformer_ctc/train.py` to train a model.
See <https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
for how to do it.
### Step 2: Compute framewise alignment
Run
```
# Choose a checkpoint and determine the number of checkpoints to average
epoch=30
avg=15
./conformer_ctc/ali.py \
--epoch $epoch \
--avg $avg \
--max-duration 500 \
--bucketing-sampler 0 \
--full-libri 1 \
--exp-dir conformer_ctc/exp \
--lang-dir data/lang_bpe_5000 \
--ali-dir data/ali_5000
```
and you will get four files inside the folder `data/ali_5000`:
```
$ ls -lh data/ali_5000
total 546M
-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt
-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt
-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt
-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt
```
**Note**: It can take more than 3 hours to compute the alignments
for the training dataset, which contains 960 * 3 = 2880 hours of data
(the 960-hour train set expanded threefold by speed perturbation).
**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those
in `conformer_ctc/train.py`.
**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`.
Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`.
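For reference, this amounts to passing `preserve_id=True` when the MUSAN `CutMix` transform is built, as done in this commit. A minimal sketch (the manifest path is illustrative):

```python
# Keep the original cut IDs when mixing MUSAN noise into training cuts,
# so that the computed alignments can be matched back to their cuts.
from lhotse import load_manifest
from lhotse.dataset import CutMix

cuts_musan = load_manifest("data/fbank/cuts_musan.json.gz")  # illustrative path
transforms = [
    CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
]
```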
**TODO:** Add documentation on how to use the extracted alignments in a follow-up pull request.
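Although the follow-up documentation is still pending, the saved files can already be inspected with `load_alignments()` from `icefall/utils.py`, which this commit adds. A minimal sketch:

```python
# Load the alignments written by conformer_ctc/ali.py and peek at a few.
from icefall.utils import load_alignments

subsampling_factor, alignments = load_alignments("data/ali_5000/valid.pt")
# `alignments` maps a cut (utterance) ID to its framewise token IDs,
# one ID per frame after subsampling by `subsampling_factor`.
for cut_id, token_ids in list(alignments.items())[:3]:
    print(cut_id, len(token_ids), token_ids[:10])
```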

View File

@@ -0,0 +1,314 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
encode_supervisions,
get_alignments,
get_env_info,
save_alignments,
setup_logger,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_5000",
help="The lang dir",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--ali-dir",
type=str,
default="data/ali_500",
help="The experiment dir",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"output_beam": 10,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def compute_alignments(
model: torch.nn.Module,
dl: torch.utils.data.DataLoader,
params: AttributeDict,
graph_compiler: BpeCtcTrainingGraphCompiler,
) -> List[Tuple[str, List[int]]]:
"""Compute the framewise alignments of a dataset.
Args:
model:
The neural network model.
dl:
Dataloader containing the dataset.
params:
Parameters for computing alignments.
graph_compiler:
It converts token IDs to decoding graphs.
Returns:
Return a list of tuples. Each tuple contains two entries:
- Utterance ID
- Framewise alignments (token IDs) after subsampling
"""
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
num_cuts = 0
device = graph_compiler.device
ans = []
for batch_idx, batch in enumerate(dl):
feature = batch["inputs"]
# at entry, feature is [N, T, C]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
cut_ids = []
for cut in supervisions["cut"]:
assert len(cut.supervisions) == 1
cut_ids.append(cut.id)
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
# we need also to sort cut_ids as encode_supervisions()
# reorders "texts".
# In general, new2old is an identity map since lhotse sorts the returned
# cuts by duration in descending order
new2old = supervision_segments[:, 0].tolist()
cut_ids = [cut_ids[i] for i in new2old]
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
lattice = k2.intersect_dense(
decoding_graph,
dense_fsa_vec,
params.output_beam,
)
best_path = one_best_decoding(
lattice=lattice,
use_double_scores=params.use_double_scores,
)
ali_ids = get_alignments(best_path)
assert len(ali_ids) == len(cut_ids)
ans += list(zip(cut_ids, ali_ids))
num_cuts += len(ali_ids)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return ans
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert args.return_cuts is True
assert args.concatenate_cuts is False
if args.full_libri is False:
print("Changing --full-libri to True")
args.full_libri = True
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/ali")
logging.info("Computing alignment - started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
test_dl = librispeech.test_dataloaders() # a list
ali_dir = Path(params.ali_dir)
ali_dir.mkdir(exist_ok=True)
enabled_datasets = {
"test_clean": test_dl[0],
"test_other": test_dl[1],
"train-960": train_dl,
"valid": valid_dl,
}
# For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
# compute the alignments if you use --max-duration=500
#
# There are 960 * 3 = 2880 hours of data and it takes only
# 3 hours 40 minutes to get the alignment.
# The RTF is roughly: 3.67 / 2880 = 0.0012743
#
# At the end, you would see
# 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa
# 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa
for name, dl in enabled_datasets.items():
logging.info(f"Processing {name}")
if name == "train-960":
logging.info(
f"It will take about 3 hours 40 minutes for {name}, "
"which contains 960 * 3 = 2880 hours of data"
)
alignments = compute_alignments(
model=model,
dl=dl,
params=params,
graph_compiler=graph_compiler,
)
num_utt = len(alignments)
alignments = dict(alignments)
assert num_utt == len(alignments)
filename = ali_dir / f"{name}.pt"
save_alignments(
alignments=alignments,
subsampling_factor=params.subsampling_factor,
filename=filename,
)
logging.info(
f"For dataset {name}, its alignments are saved to {filename}"
)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@@ -43,6 +43,7 @@ from icefall.decode import (
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
+ get_env_info,
get_texts,
setup_logger,
store_transcripts,

@@ -142,7 +143,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
type=str,
- default="data/lang_bpe",
+ default="data/lang_bpe_5000",
help="The lang dir",
)

@@ -167,6 +168,7 @@ def get_params() -> AttributeDict:
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
+ "env_info": get_env_info(),
}
)
return params

View File

@@ -65,7 +65,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
type=str,
- default="data/lang_bpe",
+ default="data/lang_bpe_5000",
help="""It contains language related input files such as "lexicon.txt"
""",
)

View File

@@ -36,7 +36,7 @@ from icefall.decode import (
rescore_with_attention_decoder,
rescore_with_whole_lattice,
)
- from icefall.utils import AttributeDict, get_texts
+ from icefall.utils import AttributeDict, get_env_info, get_texts

def get_parser():

@@ -256,7 +256,7 @@ def main():
params.num_decoder_layers = 0
params.update(vars(args))
+ params["env_info"] = get_env_info()

logging.info(f"{params}")
device = torch.device("cpu")

View File

@@ -24,16 +24,14 @@ from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
- from torch import Tensor
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
+ from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter

@@ -48,6 +46,7 @@ from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
+ get_env_info,
setup_logger,
str2bool,
)

@@ -79,6 +78,13 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)

+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_bpe_5000",
+ help="lang directory",
+ )

parser.add_argument(
"--num-epochs",
type=int,

@@ -109,7 +115,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
type=str,
- default="data/lang_bpe",
+ default="data/lang_bpe_5000",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"

@@ -185,7 +191,7 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
- "log_interval": 10,
+ "log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000,
# parameters for conformer

@@ -204,6 +210,7 @@ def get_params() -> AttributeDict:
"weight_decay": 1e-6,
"lr_factor": 5.0,
"warm_step": 80000,
+ "env_info": get_env_info(),
}
)

View File

@@ -41,6 +41,8 @@ dl_dir=$PWD/download
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
5000
+ 2000
+ 1000
500
)

@@ -191,5 +193,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
./local/compile_hlg.py --lang-dir $lang_dir
done
fi
- cd data && ln -sfv lang_bpe_5000 lang_bpe

View File

@@ -21,10 +21,6 @@ from functools import lru_cache
from pathlib import Path
from typing import List, Union
- from torch.utils.data import DataLoader
- from icefall.dataset.datamodule import DataModule
- from icefall.utils import str2bool
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,

@@ -36,6 +32,10 @@ from lhotse.dataset import (
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
+ from torch.utils.data import DataLoader
+ from icefall.dataset.datamodule import DataModule
+ from icefall.utils import str2bool

class LibriSpeechAsrDataModule(DataModule):

@@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule):
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
logging.info("About to create train dataset")
- transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
+ transforms = [
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ ]
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "

View File

@@ -39,6 +39,7 @@ from icefall.decode import (
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
+ get_env_info,
get_texts,
setup_logger,
store_transcripts,

@@ -134,6 +135,7 @@ def get_params() -> AttributeDict:
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
+ "env_info": get_env_info(),
}
)
return params

View File

@@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_whole_lattice,
)
- from icefall.utils import AttributeDict, get_texts
+ from icefall.utils import AttributeDict, get_env_info, get_texts

def get_parser():

@@ -159,6 +159,7 @@ def main():
params = get_params()
params.update(vars(args))
+ params["env_info"] = get_env_info()

logging.info(f"{params}")
device = torch.device("cpu")

View File

@@ -28,11 +28,10 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
- from torch import Tensor
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.utils import fix_random_seed
from model import TdnnLstm
+ from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR

@@ -47,6 +46,7 @@ from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
+ get_env_info,
setup_logger,
str2bool,
)

@@ -171,6 +171,7 @@ def get_params() -> AttributeDict:
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
+ "env_info": get_env_info(),
}
)

View File

@@ -17,6 +17,7 @@ from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
+ get_env_info,
get_texts,
setup_logger,
store_transcripts,

@@ -256,6 +257,7 @@ def main():
params = get_params()
params.update(vars(args))
+ params["env_info"] = get_env_info()

setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")

View File

@@ -29,7 +29,7 @@ from model import Tdnn
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
- from icefall.utils import AttributeDict, get_texts
+ from icefall.utils import AttributeDict, get_env_info, get_texts

def get_parser():

@@ -116,6 +116,7 @@ def main():
params = get_params()
params.update(vars(args))
+ params["env_info"] = get_env_info()

logging.info(f"{params}")
device = torch.device("cpu")

View File

@@ -11,10 +11,10 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
- from torch import Tensor
from asr_datamodule import YesNoAsrDataModule
from lhotse.utils import fix_random_seed
from model import Tdnn
+ from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter

@@ -24,7 +24,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
- from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+ from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_env_info,
+ setup_logger,
+ str2bool,
+ )

def get_parser():

@@ -465,6 +471,7 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
+ params["env_info"] = get_env_info()

fix_random_seed(42)

if world_size > 1:

View File

@@ -17,18 +17,21 @@
import argparse
- import logging
import collections
+ import logging
import os
import subprocess
+ import sys
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
- from typing import Dict, Iterable, List, TextIO, Tuple, Union
+ from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
import k2
+ import k2.version
import kaldialign
+ import lhotse
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
@@ -135,17 +138,82 @@ def setup_logger(
logging.getLogger("").addHandler(console)

- def get_env_info():
- """
- TODO:
- """
- return {
- "k2-git-sha1": None,
- "k2-version": None,
- "lhotse-version": None,
- "torch-version": None,
- "icefall-sha1": None,
- "icefall-version": None,
- }

def get_git_sha1():
git_commit = (
subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
dirty_commit = (
len(
subprocess.run(
["git", "diff", "--shortstat"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
> 0
)
git_commit = (
git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
)
return git_commit
def get_git_date():
git_date = (
subprocess.run(
["git", "log", "-1", "--format=%ad", "--date=local"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_date
def get_git_branch_name():
git_date = (
subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_date
def get_env_info() -> Dict[str, Any]:
"""Get the environment information."""
return {
"k2-version": k2.version.__version__,
"k2-build-type": k2.version.__build_type__,
"k2-with-cuda": k2.with_cuda,
"k2-git-sha1": k2.version.__git_sha1__,
"k2-git-date": k2.version.__git_date__,
"lhotse-version": lhotse.__version__,
"torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3],
"icefall-git-branch": get_git_branch_name(),
"icefall-git-sha1": get_git_sha1(),
"icefall-git-date": get_git_date(),
"icefall-path": str(Path(__file__).resolve().parent.parent),
"k2-path": str(Path(k2.__file__).resolve()),
"lhotse-path": str(Path(lhotse.__file__).resolve()),
}
@@ -238,6 +306,73 @@ def get_texts(
return aux_labels.tolist()
def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
"""Extract the token IDs (from best_paths.labels) from the best-path FSAs.
Args:
best_paths:
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
Returns:
Returns a list of lists of int, containing the token sequences we
decoded. For `ans[i]`, its length equals to the number of frames
after subsampling of the i-th utterance in the batch.
"""
# arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
label_shape = best_paths.arcs.shape().remove_axis(1)
# label_shape has axes [fsa][arc]
labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous())
labels = labels.remove_values_eq(-1)
return labels.tolist()
def save_alignments(
alignments: Dict[str, List[int]],
subsampling_factor: int,
filename: str,
) -> None:
"""Save alignments to a file.
Args:
alignments:
A dict containing alignments. Keys of the dict are utterances and
values are the corresponding framewise alignments after subsampling.
subsampling_factor:
The subsampling factor of the model.
filename:
Path to save the alignments.
Returns:
Return None.
"""
ali_dict = {
"subsampling_factor": subsampling_factor,
"alignments": alignments,
}
torch.save(ali_dict, filename)
def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
"""Load alignments from a file.
Args:
filename:
Path to the file containing alignment information.
The file should be saved by :func:`save_alignments`.
Returns:
Return a tuple containing:
- subsampling_factor: The subsampling_factor used to compute
the alignments.
- alignments: A dict containing utterances and their corresponding
framewise alignment, after subsampling.
"""
ali_dict = torch.load(filename)
subsampling_factor = ali_dict["subsampling_factor"]
alignments = ali_dict["alignments"]
return subsampling_factor, alignments
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str]]
) -> None:

View File

@@ -20,7 +20,12 @@ import k2
import pytest
import torch
- from icefall.utils import AttributeDict, encode_supervisions, get_texts
+ from icefall.utils import (
+ AttributeDict,
+ encode_supervisions,
+ get_env_info,
+ get_texts,
+ )

@pytest.fixture

@@ -108,6 +113,7 @@ def test_attribute_dict():
assert s["b"] == 20
s.c = 100
assert s["c"] == 100
assert hasattr(s, "a")
assert hasattr(s, "b")
assert getattr(s, "a") == 10

@@ -119,3 +125,8 @@ def test_attribute_dict():
del s.a
except AttributeError as ex:
print(f"Caught exception: {ex}")
def test_get_env_info():
s = get_env_info()
print(s)