Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Add torch script support for Aishell and update documents (#124)
* Add aishell recipe
* Remove unnecessary code and update docs
* adapt to k2 v1.7, add docs and results
* Update conformer ctc model
* Update docs, pretrained.py & results
* Fix code style
* Fix code style
* Fix code style
* Minor fix
* Minor fix
* Fix pretrained.py
* Update pretrained model & corresponding docs
* Export torch script model for Aishell
* Add C++ deployment docs
* Minor fixes
* Fix unit test
* Update Readme
This commit is contained in: parent 30c43b7f69, commit 4151cca147
README.md (31 lines changed)
@@ -12,10 +12,11 @@ for installation.
 Please refer to <https://icefall.readthedocs.io/en/latest/recipes/index.html>
 for more information.
 
-We provide three recipes at present:
+We provide four recipes at present:
 
 - [yesno][yesno]
 - [LibriSpeech][librispeech]
+- [Aishell][aishell]
 - [TIMIT][timit]
 
 ### yesno
@@ -57,6 +58,31 @@ The WER for this model is:
 
 We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
 
+### Aishell
+
+We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]
+and [TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc].
+
+#### Conformer CTC Model
+
+The best CER we currently have is:
+
+|     | test |
+|-----|------|
+| CER | 4.26 |
+
+We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
+
+#### TDNN LSTM CTC Model
+
+The CER for this model is:
+
+|     | test  |
+|-----|-------|
+| CER | 10.16 |
+
+We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z?usp=sharing)
+
 ### TIMIT
@@ -99,9 +125,12 @@ Please see: [ How to prepare data for training and decoding
 
 - (2) How to start the training, either with a single GPU or multiple GPUs
-- (3) How to do decoding after training, with 1best and attention decoder rescoring
+- (3) How to do decoding after training, with ctc-decoding, 1best and attention decoder rescoring
 - (4) How to use a pre-trained model, provided by us
 
 Data preparation
@@ -623,3 +623,125 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained

**Congratulations!** You have finished the aishell ASR recipe with
conformer CTC models in ``icefall``.

If you want to deploy your trained model in C++, please read the following section.

Deployment with C++
-------------------

This section describes how to deploy the pre-trained model in C++, without
Python dependencies.

.. HINT::

  At present, it does NOT support streaming decoding.

First, let us compile k2 from source:

.. code-block:: bash

  $ cd $HOME
  $ git clone https://github.com/k2-fsa/k2
  $ cd k2
  $ git checkout v2.0-pre

.. CAUTION::

  You have to switch to the branch ``v2.0-pre``!

.. code-block:: bash

  $ mkdir build-release
  $ cd build-release
  $ cmake -DCMAKE_BUILD_TYPE=Release ..
  $ make -j hlg_decode

  # You will find four binaries in `./bin`, i.e. ./bin/hlg_decode, ...

Now you are ready to go!

Assume you have run:

.. code-block:: bash

  $ cd k2/build-release
  $ ln -s /path/to/icefall-asr-aishell-conformer-ctc ./

To view the usage of ``./bin/hlg_decode``, run:

.. code-block::

  $ ./bin/hlg_decode

It will show you the following message:

.. code-block:: bash

  Please provide --nn_model

  This file implements decoding with an HLG decoding graph.

  Usage:
    ./bin/hlg_decode \
      --use_gpu true \
      --nn_model <path to torch scripted pt file> \
      --hlg <path to HLG.pt> \
      --word_table <path to words.txt> \
      <path to foo.wav> \
      <path to bar.wav> \
      <more waves if any>

  To see all possible options, use
    ./bin/hlg_decode --help

  Caution:
   - Only sound files (*.wav) with single channel are supported.
   - It assumes the model is conformer_ctc/transformer.py from icefall.
     If you use a different model, you have to change the code
     related to `model.forward` in this file.

HLG decoding
^^^^^^^^^^^^

.. code-block:: bash

  ./bin/hlg_decode \
    --use_gpu true \
    --nn_model icefall_asr_aishell_conformer_ctc/exp/cpu_jit.pt \
    --hlg icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \
    --word_table icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \
    icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \
    icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \
    icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav

The output is:

.. code-block::

  2021-11-18 14:48:20.89 [I] k2/torch/bin/hlg_decode.cu:115:int main(int, char**) Device: cpu
  2021-11-18 14:48:20.89 [I] k2/torch/bin/hlg_decode.cu:124:int main(int, char**) Load wave files
  2021-11-18 14:48:20.97 [I] k2/torch/bin/hlg_decode.cu:131:int main(int, char**) Build Fbank computer
  2021-11-18 14:48:20.98 [I] k2/torch/bin/hlg_decode.cu:142:int main(int, char**) Compute features
  2021-11-18 14:48:20.115 [I] k2/torch/bin/hlg_decode.cu:150:int main(int, char**) Load neural network model
  2021-11-18 14:48:20.693 [I] k2/torch/bin/hlg_decode.cu:165:int main(int, char**) Compute nnet_output
  2021-11-18 14:48:23.182 [I] k2/torch/bin/hlg_decode.cu:180:int main(int, char**) Load icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt
  2021-11-18 14:48:33.489 [I] k2/torch/bin/hlg_decode.cu:185:int main(int, char**) Decoding
  2021-11-18 14:48:45.217 [I] k2/torch/bin/hlg_decode.cu:216:int main(int, char**)
  Decoding result:

  icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav
  甚至 出现 交易 几乎 停止 的 情况

  icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav
  一二 线 城市 虽然 也 处于 调整 中

  icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
  但 因为 聚集 了 过多 公共 资源

There is a Colab notebook showing you how to run a torch scripted model in C++.
Please see |aishell asr conformer ctc torch script colab notebook|

.. |aishell asr conformer ctc torch script colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/drive/1Vh7RER7saTW01DtNbvr7CY7ovNZgmfWz?usp=sharing
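Before handing ``cpu_jit.pt`` to the C++ binary, it can help to confirm that the exported file really is a self-contained TorchScript module. The following is a minimal sketch, not part of the recipe; it only assumes that the export step described above has produced ``icefall_asr_aishell_conformer_ctc/exp/cpu_jit.pt``.

.. code-block:: python

  # Illustrative sanity check of the exported TorchScript model.
  import torch

  model = torch.jit.load("icefall_asr_aishell_conformer_ctc/exp/cpu_jit.pt")
  model.eval()

  # A scripted module can be inspected without the original Python class:
  print(type(model))       # a RecursiveScriptModule
  print(model.code[:300])  # TorchScript code of forward(), truncated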
@@ -38,14 +38,13 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_attention_decoder,
 )
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
     str2bool,
     write_error_stats,
 )
 
@@ -113,17 +112,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--export",
-        type=str2bool,
-        default=False,
-        help="""When enabled, the averaged model is saved to
-        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
-        pretrained.pt contains a dict {"model": model.state_dict()},
-        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
-        """,
-    )
-
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -544,13 +532,6 @@ def main():
     model.to(device)
     model.load_state_dict(average_checkpoints(filenames, device=device))
 
-    if params.export:
-        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
-        return
-
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
egs/aishell/ASR/conformer_ctc/export.py (new file, 165 lines)
@@ -0,0 +1,165 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script converts several saved checkpoints
# to a single one using model averaging.

import argparse
import logging
from pathlib import Path

import torch
from conformer import Conformer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=84,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=25,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="conformer_ctc/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="""It contains language related input files such as "lexicon.txt"
        """,
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=True,
        help="""True to save a model after applying torch.jit.script.
        """,
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "feature_dim": 80,
            "subsampling_factor": 4,
            "use_feat_batchnorm": True,
            "attention_dim": 512,
            "nhead": 4,
            "num_decoder_layers": 6,
        }
    )
    return params


def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)

    params = get_params()
    params.update(vars(args))

    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=False,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )
    model.to(device)

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if i >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames))

    model.to("cpu")
    model.eval()

    if params.jit:
        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = params.exp_dir / "cpu_jit.pt"
        model.save(str(filename))
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torch.jit.script")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
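For readers unfamiliar with `average_checkpoints`, the averaging that `export.py` relies on boils down to an element-wise mean over the saved `state_dict`s. A rough, self-contained sketch of the idea follows; it is not the icefall implementation itself, and it assumes each `epoch-*.pt` stores a `{"model": state_dict}` dict, as in the script above.

```python
# Rough sketch of checkpoint averaging (illustrative only).
import torch


def average_state_dicts(filenames):
    avg = None
    for f in filenames:
        # each epoch-*.pt is assumed to store {"model": state_dict}
        sd = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in sd.items()}
        else:
            for k in avg:
                avg[k] += sd[k].float()
    for k in avg:
        avg[k] /= len(filenames)
    return avg
```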
egs/aishell/ASR/conformer_ctc/label_smoothing.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class LabelSmoothingLoss(torch.nn.Module):
    """
    Implement the LabelSmoothingLoss proposed in the following paper
    https://arxiv.org/pdf/1512.00567.pdf
    (Rethinking the Inception Architecture for Computer Vision)
    """

    def __init__(
        self,
        ignore_index: int = -1,
        label_smoothing: float = 0.1,
        reduction: str = "sum",
    ) -> None:
        """
        Args:
          ignore_index:
            ignored class id
          label_smoothing:
            smoothing rate (0.0 means the conventional cross entropy loss)
          reduction:
            It has the same meaning as the reduction in
            `torch.nn.CrossEntropyLoss`. It can be one of the following three
            values: (1) "none": No reduction will be applied. (2) "mean": the
            mean of the output is taken. (3) "sum": the output will be summed.
        """
        super().__init__()
        assert 0.0 <= label_smoothing < 1.0
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute loss between x and target.

        Args:
          x:
            prediction of dimension
            (batch_size, input_length, number_of_classes).
          target:
            target masked with self.ignore_index of
            dimension (batch_size, input_length).

        Returns:
          A scalar tensor containing the loss without normalization.
        """
        assert x.ndim == 3
        assert target.ndim == 2
        assert x.shape[:2] == target.shape
        num_classes = x.size(-1)
        x = x.reshape(-1, num_classes)
        # Now x is of shape (N*T, C)

        # We don't want to change target in-place below,
        # so we make a copy of it here
        target = target.clone().reshape(-1)

        ignored = target == self.ignore_index
        target[ignored] = 0

        true_dist = torch.nn.functional.one_hot(
            target, num_classes=num_classes
        ).to(x)

        true_dist = (
            true_dist * (1 - self.label_smoothing)
            + self.label_smoothing / num_classes
        )
        # Set the value of ignored indexes to 0
        true_dist[ignored] = 0

        loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
        if self.reduction == "sum":
            return loss.sum()
        elif self.reduction == "mean":
            return loss.sum() / (~ignored).sum()
        else:
            return loss.sum(dim=-1)
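A quick illustration of how the loss above behaves on toy inputs (shapes and values are made up for the example; it assumes `label_smoothing.py` is importable, as in `transformer.py` below):

```python
# Toy usage of LabelSmoothingLoss; shapes follow the docstring above.
import torch
from label_smoothing import LabelSmoothingLoss

criterion = LabelSmoothingLoss(ignore_index=-1, label_smoothing=0.1)
logits = torch.randn(2, 5, 10)          # (batch, time, num_classes)
target = torch.randint(0, 10, (2, 5))   # (batch, time)
target[:, -1] = -1                      # padded positions use ignore_index
loss = criterion(logits, target)        # scalar, since reduction="sum"
print(loss)
```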
@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_attention_decoder,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts
 
 
 def get_parser():
@@ -190,7 +190,6 @@ def get_params() -> AttributeDict:
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
-            "env_info": get_env_info(),
         }
     )
     return params
@@ -38,12 +38,12 @@ from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from label_smoothing import LabelSmoothingLoss
 from subsampling import Conv2dSubsampling, VggSubsampling
 from torch.nn.utils.rnn import pad_sequence
 
@@ -83,8 +84,8 @@ class Transformer(nn.Module):
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
-        # self.encoder_embed converts the input of shape [N, T, num_classes]
-        # to the shape [N, T//subsampling_factor, d_model].
+        # self.encoder_embed converts the input of shape (N, T, num_classes)
+        # to the shape (N, T//subsampling_factor, d_model).
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_classes -> d_model
@@ -152,7 +153,7 @@ class Transformer(nn.Module):
                 d_model, self.decoder_num_class
             )
 
-            self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
+            self.decoder_criterion = LabelSmoothingLoss()
         else:
             self.decoder_criterion = None
 
@@ -162,7 +163,7 @@ class Transformer(nn.Module):
         """
         Args:
           x:
-            The input tensor. Its shape is [N, T, C].
+            The input tensor. Its shape is (N, T, C).
           supervision:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -171,17 +172,17 @@ class Transformer(nn.Module):
 
         Returns:
           Return a tuple containing 3 tensors:
-            - CTC output for ctc decoding. Its shape is [N, T, C]
-            - Encoder output with shape [T, N, C]. It can be used as key and
+            - CTC output for ctc decoding. Its shape is (N, T, C)
+            - Encoder output with shape (T, N, C). It can be used as key and
               value for the decoder.
             - Encoder output padding mask. It can be used as
-              memory_key_padding_mask for the decoder. Its shape is [N, T].
+              memory_key_padding_mask for the decoder. Its shape is (N, T).
              It is None if `supervision` is None.
         """
         if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
@@ -195,7 +196,7 @@ class Transformer(nn.Module):
 
         Args:
          x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
          supervisions:
            Supervision in lhotse format.
            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -206,8 +207,8 @@ class Transformer(nn.Module):
            padding mask for the decoder.
        Returns:
          Return a tuple with two tensors:
-            - The encoder output, with shape [T, N, C]
-            - encoder padding mask, with shape [N, T].
+            - The encoder output, with shape (T, N, C)
+            - encoder padding mask, with shape (N, T).
              The mask is None if `supervisions` is None.
              It is used as memory key padding mask in the decoder.
        """
@@ -225,17 +226,18 @@ class Transformer(nn.Module):
        Args:
          x:
            The output tensor from the transformer encoder.
-            Its shape is [T, N, C]
+            Its shape is (T, N, C)
 
        Returns:
          Return a tensor that can be used for CTC decoding.
-            Its shape is [N, T, C]
+            Its shape is (N, T, C)
        """
        x = self.encoder_output_layer(x)
        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
        x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
        return x
 
+    @torch.jit.export
     def decoder_forward(
         self,
         memory: torch.Tensor,
@@ -247,7 +249,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -264,11 +266,15 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
@@ -301,18 +307,19 @@ class Transformer(nn.Module):
 
         return decoder_loss
 
+    @torch.jit.export
     def decoder_nll(
         self,
         memory: torch.Tensor,
         memory_key_padding_mask: torch.Tensor,
-        token_ids: List[List[int]],
+        token_ids: List[torch.Tensor],
         sos_id: int,
         eos_id: int,
     ) -> torch.Tensor:
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -328,14 +335,23 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
         # extracted as a separate function.
+        if isinstance(token_ids[0], torch.Tensor):
+            # This branch is executed by torchscript in C++.
+            # See https://github.com/k2-fsa/k2/pull/870
+            # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+            token_ids = [tolist(t) for t in token_ids]
+
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@@ -649,25 +665,25 @@ class PositionalEncoding(nn.Module):
         self.d_model = d_model
         self.xscale = math.sqrt(self.d_model)
         self.dropout = nn.Dropout(p=dropout)
-        self.pe = None
+        # not doing: self.pe = None because of errors thrown by torchscript
+        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
 
     def extend_pe(self, x: torch.Tensor) -> None:
         """Extend the time t in the positional encoding if required.
 
-        The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
-        is [N, T, d_model]. If T > T1, then we change the shape of self.pe
-        to [N, T, d_model]. Otherwise, nothing is done.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (N, T, d_model). Otherwise, nothing is done.
 
         Args:
           x:
-            It is a tensor of shape [N, T, C].
+            It is a tensor of shape (N, T, C).
         Returns:
           Return None.
         """
         if self.pe is not None:
             if self.pe.size(1) >= x.size(1):
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
@@ -678,7 +694,7 @@ class PositionalEncoding(nn.Module):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0)
-        # Now pe is of shape [1, T, d_model], where T is x.size(1)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -687,10 +703,10 @@ class PositionalEncoding(nn.Module):
 
         Args:
          x:
-            Its shape is [N, T, C]
+            Its shape is (N, T, C)
 
        Returns:
-          Return a tensor of shape [N, T, C]
+          Return a tensor of shape (N, T, C)
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1), :]
@@ -784,73 +800,6 @@ class Noam(object):
             setattr(self, key, value)
 
 
-class LabelSmoothingLoss(nn.Module):
-    """
-    Label-smoothing loss. KL-divergence between
-    q_{smoothed ground truth prob.}(w)
-    and p_{prob. computed by model}(w) is minimized.
-    Modified from
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa
-
-    Args:
-        size: the number of class
-        padding_idx: padding_idx: ignored class id
-        smoothing: smoothing rate (0.0 means the conventional CE)
-        normalize_length: normalize loss by sequence length if True
-        criterion: loss function to be smoothed
-    """
-
-    def __init__(
-        self,
-        size: int,
-        padding_idx: int = -1,
-        smoothing: float = 0.1,
-        normalize_length: bool = False,
-        criterion: nn.Module = nn.KLDivLoss(reduction="none"),
-    ) -> None:
-        """Construct an LabelSmoothingLoss object."""
-        super(LabelSmoothingLoss, self).__init__()
-        self.criterion = criterion
-        self.padding_idx = padding_idx
-        assert 0.0 < smoothing <= 1.0
-        self.confidence = 1.0 - smoothing
-        self.smoothing = smoothing
-        self.size = size
-        self.true_dist = None
-        self.normalize_length = normalize_length
-
-    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        """
-        Compute loss between x and target.
-
-        Args:
-            x:
-                prediction of dimension
-                (batch_size, input_length, number_of_classes).
-            target:
-                target masked with self.padding_id of
-                dimension (batch_size, input_length).
-
-        Returns:
-            A scalar tensor containing the loss without normalization.
-        """
-        assert x.size(2) == self.size
-        # batch_size = x.size(0)
-        x = x.view(-1, self.size)
-        target = target.view(-1)
-        with torch.no_grad():
-            true_dist = x.clone()
-            true_dist.fill_(self.smoothing / (self.size - 1))
-            ignore = target == self.padding_idx  # (B,)
-            total = len(target) - ignore.sum().item()
-            target = target.masked_fill(ignore, 0)  # avoid -1 index
-            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
-        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
-        # denom = total if self.normalize_length else batch_size
-        denom = total if self.normalize_length else 1
-        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
-
-
 def encoder_padding_mask(
     max_len: int, supervisions: Optional[Supervisions] = None
 ) -> Optional[torch.Tensor]:
@@ -972,10 +921,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
      Return a new list-of-list, where each sublist starts
      with SOS ID.
    """
-    ans = []
-    for utt in token_ids:
-        ans.append([sos_id] + utt)
-    return ans
+    return [[sos_id] + utt for utt in token_ids]
 
 
 def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
@@ -992,7 +938,9 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
      Return a new list-of-list, where each sublist ends
      with EOS ID.
    """
-    ans = []
-    for utt in token_ids:
-        ans.append(utt + [eos_id])
-    return ans
+    return [utt + [eos_id] for utt in token_ids]
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+    """Used by jit"""
+    return torch.jit.annotate(List[int], t.tolist())
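The `@torch.jit.export` decorators and the `tolist()` helper added above follow a standard TorchScript pattern: methods other than `forward()` must be explicitly exported to stay callable on the scripted module (for example from C++), and `torch.jit.annotate` tells the compiler the element type produced by `Tensor.tolist()`. The snippet below is a tiny self-contained sketch of the same idea on a toy module, not the icefall `Transformer`.

```python
# Toy illustration of the TorchScript export/tolist pattern used above.
from typing import List

import torch


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    # Without @torch.jit.export this method would not be compiled and
    # could not be called on the scripted module.
    @torch.jit.export
    def decode(self, token_ids: List[torch.Tensor]) -> List[List[int]]:
        # C++ callers pass 1-D tensors; convert each to List[int].
        return [torch.jit.annotate(List[int], t.tolist()) for t in token_ids]


scripted = torch.jit.script(Toy())
print(scripted.decode([torch.tensor([1, 2, 3])]))  # [[1, 2, 3]]
```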
@@ -28,12 +28,12 @@ from conformer import Conformer
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import one_best_decoding
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     encode_supervisions,
     get_alignments,
-    get_env_info,
     save_alignments,
     setup_logger,
 )
@@ -40,14 +40,13 @@ from icefall.decode import (
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
     str2bool,
     write_error_stats,
 )
 
@@ -122,17 +121,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--export",
-        type=str2bool,
-        default=False,
-        help="""When enabled, the averaged model is saved to
-        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
-        pretrained.pt contains a dict {"model": model.state_dict()},
-        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
-        """,
-    )
-
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -671,13 +659,6 @@ def main():
     model.to(device)
     model.load_state_dict(average_checkpoints(filenames, device=device))
 
-    if params.export:
-        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
-        return
-
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
@@ -36,7 +36,7 @@ from icefall.decode import (
     rescore_with_attention_decoder,
     rescore_with_whole_lattice,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts
 
 
 def get_parser():
@@ -256,7 +256,6 @@ def main():
     params.num_decoder_layers = 0
 
     params.update(vars(args))
-    params["env_info"] = get_env_info()
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -41,12 +41,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -36,10 +36,10 @@ from icefall.decode import (
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_whole_lattice,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts
 
 
 def get_parser():
@@ -159,7 +159,6 @@ def main():
 
     params = get_params()
     params.update(vars(args))
-    params["env_info"] = get_env_info()
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_whole_lattice,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts
 
 
 def get_parser():
@@ -159,7 +159,6 @@ def main():
 
     params = get_params()
     params.update(vars(args))
-    params["env_info"] = get_env_info()
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_whole_lattice,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts
 
 
 def get_parser():
@@ -159,7 +159,6 @@ def main():
 
     params = get_params()
     params.update(vars(args))
-    params["env_info"] = get_env_info()
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -14,10 +14,10 @@ from model import Tdnn
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import get_lattice, one_best_decoding
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
@@ -29,7 +29,7 @@ from model import Tdnn
 from torch.nn.utils.rnn import pad_sequence
 
 from icefall.decode import get_lattice, one_best_decoding
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts
 
 
 def get_parser():
@@ -116,7 +116,6 @@ def main():
 
     params = get_params()
     params.update(vars(args))
-    params["env_info"] = get_env_info()
     logging.info(f"{params}")
 
     device = torch.device("cpu")
@@ -22,15 +22,10 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    get_env_info,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 
 def get_parser():
icefall/env.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# Copyright      2021  Xiaomi Corp.       (authors: Fangjun Kuang
#                                                   Wei Kang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import subprocess
import sys
from pathlib import Path
from typing import Any, Dict

import k2
import k2.version
import lhotse
import torch


def get_git_sha1():
    git_commit = (
        subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    dirty_commit = (
        len(
            subprocess.run(
                ["git", "diff", "--shortstat"],
                check=True,
                stdout=subprocess.PIPE,
            )
            .stdout.decode()
            .rstrip("\n")
            .strip()
        )
        > 0
    )
    git_commit = (
        git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
    )
    return git_commit


def get_git_date():
    git_date = (
        subprocess.run(
            ["git", "log", "-1", "--format=%ad", "--date=local"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    return git_date


def get_git_branch_name():
    git_date = (
        subprocess.run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    return git_date


def get_env_info() -> Dict[str, Any]:
    """Get the environment information."""
    return {
        "k2-version": k2.version.__version__,
        "k2-build-type": k2.version.__build_type__,
        "k2-with-cuda": k2.with_cuda,
        "k2-git-sha1": k2.version.__git_sha1__,
        "k2-git-date": k2.version.__git_date__,
        "lhotse-version": lhotse.__version__,
        "torch-cuda-available": torch.cuda.is_available(),
        "torch-cuda-version": torch.version.cuda,
        "python-version": sys.version[:3],
        "icefall-git-branch": get_git_branch_name(),
        "icefall-git-sha1": get_git_sha1(),
        "icefall-git-date": get_git_date(),
        "icefall-path": str(Path(__file__).resolve().parent.parent),
        "k2-path": str(Path(k2.__file__).resolve()),
        "lhotse-path": str(Path(lhotse.__file__).resolve()),
    }
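A small usage example for the new module (the keys are the ones defined above; actual values depend on the local installation):

```python
# Print the environment information collected by icefall/env.py.
from icefall.env import get_env_info

info = get_env_info()
for key, value in sorted(info.items()):
    print(f"{key}: {value}")
```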
@@ -21,17 +21,15 @@ import collections
 import logging
 import os
-import subprocess
 import sys
 from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
+from typing import Dict, Iterable, List, TextIO, Tuple, Union
 
 import k2
 import k2.version
 import kaldialign
-import lhotse
 import torch
 import torch.distributed as dist
 from torch.utils.tensorboard import SummaryWriter
@@ -137,85 +135,6 @@ def setup_logger(
     logging.getLogger("").addHandler(console)
 
 
-def get_git_sha1(): ...                        # moved, unchanged, to icefall/env.py
-
-def get_git_date(): ...                        # moved, unchanged, to icefall/env.py
-
-def get_git_branch_name(): ...                 # moved, unchanged, to icefall/env.py
-
-def get_env_info() -> Dict[str, Any]: ...      # moved, unchanged, to icefall/env.py
-
-
 class AttributeDict(dict):
     def __getattr__(self, key):
         if key in self:
@@ -20,12 +20,8 @@ import k2
 import pytest
 import torch
 
-from icefall.utils import (
-    AttributeDict,
-    encode_supervisions,
-    get_env_info,
-    get_texts,
-)
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, encode_supervisions, get_texts
 
 
 @pytest.fixture