Extract framewise alignment information using CTC decoding (#39)
* Use new APIs with k2.RaggedTensor
* Fix style issues.
* Update the installation doc, saying it requires at least k2 v1.7
* Extract framewise alignment information using CTC decoding.
* Print environment information: print information about k2, lhotse, PyTorch, and icefall.
* Fix CI.
* Fix CI.
* Compute framewise alignment information of the LibriSpeech dataset.
* Update comments for the time to compute alignments of train-960.
* Preserve cut id in mix cut transformer.
* Minor fixes.
* Add doc about how to extract framewise alignments.
parent bd7c2f7645
commit 4890e27b45
.github/workflows/test.yml (vendored): 9 changed lines
@@ -46,10 +46,18 @@ jobs:
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install libsndfile and libsox
        if: startsWith(matrix.os, 'ubuntu')
        run: |
          sudo apt update
          sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
          sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all

      - name: Install Python dependencies
        run: |
          python3 -m pip install --upgrade pip pytest
          pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
          pip install git+https://github.com/lhotse-speech/lhotse
          # icefall requirements
          pip install -r requirements.txt

@@ -88,4 +96,3 @@ jobs:
          # run tests for conformer ctc
          cd egs/librispeech/ASR/conformer_ctc
          pytest
@@ -38,14 +38,16 @@ python conformer_ctc/train.py --bucketing-sampler True \
                              --concatenate-cuts False \
                              --max-duration 200 \
                              --full-libri True \
                              --world-size 4
                              --world-size 4 \
                              --lang-dir data/lang_bpe_5000

python conformer_ctc/decode.py --nbest-scale 0.5 \
                               --epoch 34 \
                               --avg 20 \
                               --method attention-decoder \
                               --max-duration 20 \
                               --num-paths 100
                               --num-paths 100 \
                               --lang-dir data/lang_bpe_5000
```

### LibriSpeech training results (Tdnn-Lstm)
@@ -1,3 +1,53 @@
## Introduction

Please visit
<https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
for how to run this recipe.

## How to compute framewise alignment information

### Step 1: Train a model

Please use `conformer_ctc/train.py` to train a model.
See <https://icefall.readthedocs.io/en/latest/recipes/librispeech/conformer_ctc.html>
for how to do it.

### Step 2: Compute framewise alignment

Run

```
# Choose a checkpoint and determine the number of checkpoints to average
epoch=30
avg=15
./conformer_ctc/ali.py \
  --epoch $epoch \
  --avg $avg \
  --max-duration 500 \
  --bucketing-sampler 0 \
  --full-libri 1 \
  --exp-dir conformer_ctc/exp \
  --lang-dir data/lang_bpe_5000 \
  --ali-dir data/ali_5000
```
and you will get four files inside the folder `data/ali_5000`:

```
$ ls -lh data/ali_5000
total 546M
-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt
-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt
-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt
-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt
```

**Note**: It can take more than 3 hours to compute the alignments
for the training dataset, which contains 960 * 3 = 2880 hours of data
(960 hours of audio, tripled by speed perturbation).

**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those
in `conformer_ctc/train.py`.

**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`.
Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`.

**TODO**: Add documentation about how to use the extracted alignments in a follow-up pull request.
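The saved `.pt` files are plain `torch.save` dictionaries. As a rough sketch (not part of this commit's diff), they can be inspected with the `load_alignments` helper added to `icefall/utils.py` below; the file path follows the README example above:

```
# Sketch: inspect an alignment file produced by ./conformer_ctc/ali.py.
# Assumes icefall (including this commit) is importable.
from icefall.utils import load_alignments

subsampling_factor, alignments = load_alignments("data/ali_5000/test_clean.pt")
print("subsampling factor:", subsampling_factor)

# `alignments` maps cut IDs to framewise token IDs (after subsampling).
cut_id, token_ids = next(iter(alignments.items()))
print(cut_id, "->", len(token_ids), "subsampled frames")
```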
egs/librispeech/ASR/conformer_ctc/ali.py (new executable file, 314 lines)

@@ -0,0 +1,314 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path
from typing import List, Tuple

import k2
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    encode_supervisions,
    get_alignments,
    get_env_info,
    save_alignments,
    setup_logger,
)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=34,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=20,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_5000",
        help="The lang dir",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="conformer_ctc/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--ali-dir",
        type=str,
        default="data/ali_500",
        help="The directory to save the computed alignments",
    )
    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "lm_dir": Path("data/lm"),
            "feature_dim": 80,
            "nhead": 8,
            "attention_dim": 512,
            "subsampling_factor": 4,
            "num_decoder_layers": 6,
            "vgg_frontend": False,
            "use_feat_batchnorm": True,
            "output_beam": 10,
            "use_double_scores": True,
            "env_info": get_env_info(),
        }
    )
    return params


def compute_alignments(
    model: torch.nn.Module,
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    graph_compiler: BpeCtcTrainingGraphCompiler,
) -> List[Tuple[str, List[int]]]:
    """Compute the framewise alignments of a dataset.

    Args:
      model:
        The neural network model.
      dl:
        Dataloader containing the dataset.
      params:
        Parameters for computing alignments.
      graph_compiler:
        It converts token IDs to decoding graphs.
    Returns:
      Return a list of tuples. Each tuple contains two entries:
        - Utterance ID
        - Framewise alignments (token IDs) after subsampling
    """
    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"
    num_cuts = 0

    device = graph_compiler.device
    ans = []
    for batch_idx, batch in enumerate(dl):
        feature = batch["inputs"]

        # at entry, feature is [N, T, C]
        assert feature.ndim == 3
        feature = feature.to(device)

        supervisions = batch["supervisions"]

        cut_ids = []
        for cut in supervisions["cut"]:
            assert len(cut.supervisions) == 1
            cut_ids.append(cut.id)

        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        # nnet_output is [N, T, C]
        supervision_segments, texts = encode_supervisions(
            supervisions, subsampling_factor=params.subsampling_factor
        )
        # We also need to sort cut_ids because encode_supervisions()
        # reorders "texts".
        # In general, new2old is an identity map since lhotse sorts the returned
        # cuts by duration in descending order
        new2old = supervision_segments[:, 0].tolist()
        cut_ids = [cut_ids[i] for i in new2old]

        token_ids = graph_compiler.texts_to_ids(texts)
        decoding_graph = graph_compiler.compile(token_ids)

        dense_fsa_vec = k2.DenseFsaVec(
            nnet_output,
            supervision_segments,
            allow_truncate=params.subsampling_factor - 1,
        )

        lattice = k2.intersect_dense(
            decoding_graph,
            dense_fsa_vec,
            params.output_beam,
        )

        best_path = one_best_decoding(
            lattice=lattice,
            use_double_scores=params.use_double_scores,
        )

        ali_ids = get_alignments(best_path)
        assert len(ali_ids) == len(cut_ids)
        ans += list(zip(cut_ids, ali_ids))

        num_cuts += len(ali_ids)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )

    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    assert args.return_cuts is True
    assert args.concatenate_cuts is False
    if args.full_libri is False:
        print("Changing --full-libri to True")
        args.full_libri = True

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log/ali")

    logging.info("Computing alignment - started")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    graph_compiler = BpeCtcTrainingGraphCompiler(
        params.lang_dir,
        device=device,
        sos_token="<sos/eos>",
        eos_token="<sos/eos>",
    )

    logging.info("About to create model")
    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=params.vgg_frontend,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if start >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames))

    model.to(device)
    model.eval()

    librispeech = LibriSpeechAsrDataModule(args)

    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()
    test_dl = librispeech.test_dataloaders()  # a list

    ali_dir = Path(params.ali_dir)
    ali_dir.mkdir(exist_ok=True)

    enabled_datasets = {
        "test_clean": test_dl[0],
        "test_other": test_dl[1],
        "train-960": train_dl,
        "valid": valid_dl,
    }
    # For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to
    # compute the alignments if you use --max-duration=500
    #
    # There are 960 * 3 = 2880 hours data and it takes only
    # 3 hours 40 minutes to get the alignment.
    # The RTF is roughly: 3.67 / 2880 = 0.0012743
    #
    # At the end, you would see
    # 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270  # noqa
    # 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268  # noqa
    for name, dl in enabled_datasets.items():
        logging.info(f"Processing {name}")
        if name == "train-960":
            logging.info(
                f"It will take about 3 hours 40 minutes for {name}, "
                "which contains 960 * 3 = 2880 hours of data"
            )
        alignments = compute_alignments(
            model=model,
            dl=dl,
            params=params,
            graph_compiler=graph_compiler,
        )
        num_utt = len(alignments)
        alignments = dict(alignments)
        assert num_utt == len(alignments)
        filename = ali_dir / f"{name}.pt"
        save_alignments(
            alignments=alignments,
            subsampling_factor=params.subsampling_factor,
            filename=filename,
        )
        logging.info(
            f"For dataset {name}, its alignments are saved to {filename}"
        )


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
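For orientation only (not part of the diff): each entry returned by `compute_alignments` has one token ID per subsampled frame, so a frame index can be mapped back to an approximate time. The sketch below assumes the default 10 ms feature frame shift and the `subsampling_factor` of 4 used above; treating token ID 0 as the blank is also an assumption of the standard BPE CTC setup.

```
# Sketch: map subsampled frame indices back to approximate times in seconds.
from icefall.utils import load_alignments

subsampling_factor, alignments = load_alignments("data/ali_5000/valid.pt")
frame_shift = 0.01  # assumed 10 ms fbank frame shift

cut_id, token_ids = next(iter(alignments.items()))
for i, token in enumerate(token_ids):
    if token != 0:  # assumed: 0 is the blank token
        start = i * subsampling_factor * frame_shift
        print(f"{cut_id}: token {token} appears near {start:.2f}s")
```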
@@ -43,6 +43,7 @@ from icefall.decode import (
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    get_env_info,
    get_texts,
    setup_logger,
    store_transcripts,

@@ -142,7 +143,7 @@ def get_parser():
    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe",
        default="data/lang_bpe_5000",
        help="The lang dir",
    )

@@ -167,6 +168,7 @@ def get_params() -> AttributeDict:
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            "env_info": get_env_info(),
        }
    )
    return params

@@ -65,7 +65,7 @@ def get_parser():
    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe",
        default="data/lang_bpe_5000",
        help="""It contains language related input files such as "lexicon.txt"
        """,
    )

@@ -36,7 +36,7 @@ from icefall.decode import (
    rescore_with_attention_decoder,
    rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_texts
from icefall.utils import AttributeDict, get_env_info, get_texts


def get_parser():

@@ -256,7 +256,7 @@ def main():
    params.num_decoder_layers = 0

    params.update(vars(args))

    params["env_info"] = get_env_info()
    logging.info(f"{params}")

    device = torch.device("cpu")

@@ -24,16 +24,14 @@ from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple


import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch import Tensor

from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter

@@ -48,6 +46,7 @@ from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    encode_supervisions,
    get_env_info,
    setup_logger,
    str2bool,
)

@@ -79,6 +78,13 @@ def get_parser():
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_5000",
        help="lang directory",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,

@@ -109,7 +115,7 @@ def get_parser():
    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe",
        default="data/lang_bpe_5000",
        help="""The lang dir
        It contains language related input files such as
        "lexicon.txt"

@@ -185,7 +191,7 @@ def get_params() -> AttributeDict:
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 10,
            "log_interval": 50,
            "reset_interval": 200,
            "valid_interval": 3000,
            # parameters for conformer

@@ -204,6 +210,7 @@ def get_params() -> AttributeDict:
            "weight_decay": 1e-6,
            "lr_factor": 5.0,
            "warm_step": 80000,
            "env_info": get_env_info(),
        }
    )

@@ -41,6 +41,8 @@ dl_dir=$PWD/download
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
  5000
  2000
  1000
  500
)

@@ -191,5 +193,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
    ./local/compile_hlg.py --lang-dir $lang_dir
  done
fi

cd data && ln -sfv lang_bpe_5000 lang_bpe

@@ -21,10 +21,6 @@ from functools import lru_cache
from pathlib import Path
from typing import List, Union

from torch.utils.data import DataLoader

from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
    BucketingSampler,

@@ -36,6 +32,10 @@ from lhotse.dataset import (
    SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader

from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool


class LibriSpeechAsrDataModule(DataModule):

@@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule):
        cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")

        logging.info("About to create train dataset")
        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
        transforms = [
            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
        ]
        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
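The `preserve_id=True` change above is what the README caution refers to: with it, a cut mixed with MUSAN noise keeps the ID of the original cut, so the cut IDs recorded by `conformer_ctc/ali.py` still match the training cuts. A minimal sketch (the manifest path is hypothetical):

```
# Sketch: CutMix that keeps the original cut IDs, as in the data module above.
from lhotse import load_manifest
from lhotse.dataset import CutMix

cuts_musan = load_manifest("data/fbank/cuts_musan.json.gz")  # hypothetical path
transforms = [
    CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
]
```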
@@ -39,6 +39,7 @@ from icefall.decode import (
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    get_env_info,
    get_texts,
    setup_logger,
    store_transcripts,

@@ -134,6 +135,7 @@ def get_params() -> AttributeDict:
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            "env_info": get_env_info(),
        }
    )
    return params

@@ -34,7 +34,7 @@ from icefall.decode import (
    one_best_decoding,
    rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_texts
from icefall.utils import AttributeDict, get_env_info, get_texts


def get_parser():

@@ -159,6 +159,7 @@ def main():

    params = get_params()
    params.update(vars(args))
    params["env_info"] = get_env_info()
    logging.info(f"{params}")

    device = torch.device("cpu")

@@ -28,11 +28,10 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch import Tensor

from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.utils import fix_random_seed
from model import TdnnLstm
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR

@@ -47,6 +46,7 @@ from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    encode_supervisions,
    get_env_info,
    setup_logger,
    str2bool,
)

@@ -171,6 +171,7 @@ def get_params() -> AttributeDict:
            "beam_size": 10,
            "reduction": "sum",
            "use_double_scores": True,
            "env_info": get_env_info(),
        }
    )

@@ -17,6 +17,7 @@ from icefall.decode import get_lattice, one_best_decoding
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    get_env_info,
    get_texts,
    setup_logger,
    store_transcripts,

@@ -256,6 +257,7 @@ def main():

    params = get_params()
    params.update(vars(args))
    params["env_info"] = get_env_info()

    setup_logger(f"{params.exp_dir}/log/log-decode")
    logging.info("Decoding started")

@@ -29,7 +29,7 @@ from model import Tdnn
from torch.nn.utils.rnn import pad_sequence

from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
from icefall.utils import AttributeDict, get_env_info, get_texts


def get_parser():

@@ -116,6 +116,7 @@ def main():

    params = get_params()
    params.update(vars(args))
    params["env_info"] = get_env_info()
    logging.info(f"{params}")

    device = torch.device("cpu")

@@ -11,10 +11,10 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from asr_datamodule import YesNoAsrDataModule
from lhotse.utils import fix_random_seed
from model import Tdnn
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter

@@ -24,7 +24,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    get_env_info,
    setup_logger,
    str2bool,
)


def get_parser():

@@ -465,6 +471,7 @@ def run(rank, world_size, args):
    """
    params = get_params()
    params.update(vars(args))
    params["env_info"] = get_env_info()

    fix_random_seed(42)
    if world_size > 1:
icefall/utils.py: 159 changed lines
@@ -17,18 +17,21 @@


import argparse
import logging
import collections
import logging
import os
import subprocess
import sys
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union
from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union

import k2
import k2.version
import kaldialign
import lhotse
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

@@ -135,17 +138,82 @@ def setup_logger(
    logging.getLogger("").addHandler(console)


def get_env_info():
    """
    TODO:
    """
def get_git_sha1():
    git_commit = (
        subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    dirty_commit = (
        len(
            subprocess.run(
                ["git", "diff", "--shortstat"],
                check=True,
                stdout=subprocess.PIPE,
            )
            .stdout.decode()
            .rstrip("\n")
            .strip()
        )
        > 0
    )
    git_commit = (
        git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
    )
    return git_commit


def get_git_date():
    git_date = (
        subprocess.run(
            ["git", "log", "-1", "--format=%ad", "--date=local"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    return git_date


def get_git_branch_name():
    git_date = (
        subprocess.run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    return git_date


def get_env_info() -> Dict[str, Any]:
    """Get the environment information."""
    return {
        "k2-git-sha1": None,
        "k2-version": None,
        "lhotse-version": None,
        "torch-version": None,
        "icefall-sha1": None,
        "icefall-version": None,
        "k2-version": k2.version.__version__,
        "k2-build-type": k2.version.__build_type__,
        "k2-with-cuda": k2.with_cuda,
        "k2-git-sha1": k2.version.__git_sha1__,
        "k2-git-date": k2.version.__git_date__,
        "lhotse-version": lhotse.__version__,
        "torch-cuda-available": torch.cuda.is_available(),
        "torch-cuda-version": torch.version.cuda,
        "python-version": sys.version[:3],
        "icefall-git-branch": get_git_branch_name(),
        "icefall-git-sha1": get_git_sha1(),
        "icefall-git-date": get_git_date(),
        "icefall-path": str(Path(__file__).resolve().parent.parent),
        "k2-path": str(Path(k2.__file__).resolve()),
        "lhotse-path": str(Path(lhotse.__file__).resolve()),
    }


@@ -238,6 +306,73 @@ def get_texts(
    return aux_labels.tolist()


def get_alignments(best_paths: k2.Fsa) -> List[List[int]]:
    """Extract the token IDs (from best_paths.labels) from the best-path FSAs.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
    Returns:
      Returns a list of lists of int, containing the token sequences we
      decoded. For `ans[i]`, its length equals to the number of frames
      after subsampling of the i-th utterance in the batch.
    """
    # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
    label_shape = best_paths.arcs.shape().remove_axis(1)
    # label_shape has axes [fsa][arc]
    labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous())
    labels = labels.remove_values_eq(-1)
    return labels.tolist()


def save_alignments(
    alignments: Dict[str, List[int]],
    subsampling_factor: int,
    filename: str,
) -> None:
    """Save alignments to a file.

    Args:
      alignments:
        A dict containing alignments. Keys of the dict are utterances and
        values are the corresponding framewise alignments after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      filename:
        Path to save the alignments.
    Returns:
      Return None.
    """
    ali_dict = {
        "subsampling_factor": subsampling_factor,
        "alignments": alignments,
    }
    torch.save(ali_dict, filename)


def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
    """Load alignments from a file.

    Args:
      filename:
        Path to the file containing alignment information.
        The file should be saved by :func:`save_alignments`.
    Returns:
      Return a tuple containing:
        - subsampling_factor: The subsampling_factor used to compute
          the alignments.
        - alignments: A dict containing utterances and their corresponding
          framewise alignment, after subsampling.
    """
    ali_dict = torch.load(filename)
    subsampling_factor = ali_dict["subsampling_factor"]
    alignments = ali_dict["alignments"]
    return subsampling_factor, alignments


def store_transcripts(
    filename: Pathlike, texts: Iterable[Tuple[str, str]]
) -> None:
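As a usage note (not part of the diff), the new `get_env_info()` is what the recipes above store under `params["env_info"]`; a minimal sketch of calling it directly:

```
# Sketch: log the collected environment information, mirroring the recipes above.
import logging

from icefall.utils import get_env_info

logging.basicConfig(level=logging.INFO)
# Prints k2/lhotse versions, git SHA and branch, CUDA availability, paths, etc.
logging.info(get_env_info())
```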
@@ -20,7 +20,12 @@ import k2
import pytest
import torch

from icefall.utils import AttributeDict, encode_supervisions, get_texts
from icefall.utils import (
    AttributeDict,
    encode_supervisions,
    get_env_info,
    get_texts,
)


@pytest.fixture

@@ -108,6 +113,7 @@ def test_attribute_dict():
    assert s["b"] == 20
    s.c = 100
    assert s["c"] == 100

    assert hasattr(s, "a")
    assert hasattr(s, "b")
    assert getattr(s, "a") == 10

@@ -119,3 +125,8 @@ def test_attribute_dict():
        del s.a
    except AttributeError as ex:
        print(f"Caught exception: {ex}")


def test_get_env_info():
    s = get_env_info()
    print(s)