Add torch script support for Aishell and update documents (#124)

* Add aishell recipe

* Remove unnecessary code and update docs

* adapt to k2 v1.7, add docs and results

* Update conformer ctc model

* Update docs, pretrained.py & results

* Fix code style

* Fix code style

* Fix code style

* Minor fix

* Minor fix

* Fix pretrained.py

* Update pretrained model & corresponding docs

* Export torch script model for Aishell

* Add C++ deployment docs

* Minor fixes

* Fix unit test

* Update Readme
Wei Kang 2021-11-19 16:37:05 +08:00 committed by GitHub
parent 30c43b7f69
commit 4151cca147
25 changed files with 597 additions and 263 deletions

View File

@@ -12,10 +12,11 @@ for installation.
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/index.html>
for more information.

- We provide three recipes at present:
+ We provide four recipes at present:

  - [yesno][yesno]
  - [LibriSpeech][librispeech]
+ - [Aishell][aishell]
  - [TIMIT][timit]

### yesno

@@ -57,6 +58,31 @@ The WER for this model is:
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
### Aishell
We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]
and [TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc].
#### Conformer CTC Model
The best CER we currently have is:
| | test |
|-----|------|
| CER | 4.26 |
We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
#### TDNN LSTM CTC Model
The CER for this model is:
| | test |
|-----|-------|
| CER | 10.16 |
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z?usp=sharing)
### TIMIT

@@ -99,9 +125,12 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
+ [Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc
+ [Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc
[TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc
[TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR
+ [aishell]: egs/aishell/ASR
[timit]: egs/timit/ASR
[k2]: https://github.com/k2-fsa/k2

View File

@@ -18,7 +18,7 @@ In this tutorial, you will learn:
  - (1) How to prepare data for training and decoding
  - (2) How to start the training, either with a single GPU or multiple GPUs
- - (3) How to do decoding after training, with 1best and attention decoder rescoring
+ - (3) How to do decoding after training, with ctc-decoding, 1best and attention decoder rescoring
  - (4) How to use a pre-trained model, provided by us

Data preparation

@@ -623,3 +623,125 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained
**Congratulations!** You have finished the aishell ASR recipe with
conformer CTC models in ``icefall``.
If you want to deploy your trained model in C++, please read the following section.
Deployment with C++
-------------------
This section describes how to deploy the pre-trained model in C++, without
Python dependencies.
.. HINT::
At present, it does NOT support streaming decoding.
First, let us compile k2 from source:
.. code-block:: bash
$ cd $HOME
$ git clone https://github.com/k2-fsa/k2
$ cd k2
$ git checkout v2.0-pre
.. CAUTION::
You have to switch to the branch ``v2.0-pre``!
.. code-block:: bash
$ mkdir build-release
$ cd build-release
$ cmake -DCMAKE_BUILD_TYPE=Release ..
$ make -j hlg_decode
# You will find four binaries in `./bin`, i.e. ./bin/hlg_decode,
Now you are ready to go!
Assume you have run:
.. code-block:: bash
$ cd k2/build-release
$ ln -s /path/to/icefall-asr-aishell-conformer-ctc ./
To view the usage of ``./bin/hlg_decode``, run:
.. code-block::
$ ./bin/hlg_decode
It will show you the following message:
.. code-block:: bash
Please provide --nn_model
This file implements decoding with an HLG decoding graph.
Usage:
./bin/hlg_decode \
--use_gpu true \
--nn_model <path to torch scripted pt file> \
--hlg <path to HLG.pt> \
--word_table <path to words.txt> \
<path to foo.wav> \
<path to bar.wav> \
<more waves if any>
To see all possible options, use
./bin/hlg_decode --help
Caution:
- Only sound files (*.wav) with single channel are supported.
- It assumes the model is conformer_ctc/transformer.py from icefall.
If you use a different model, you have to change the code
related to `model.forward` in this file.
HLG decoding
^^^^^^^^^^^^
.. code-block:: bash
./bin/hlg_decode \
--use_gpu true \
--nn_model icefall_asr_aishell_conformer_ctc/exp/cpu_jit.pt \
--hlg icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \
--word_table icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
The output is:
.. code-block::
2021-11-18 14:48:20.89 [I] k2/torch/bin/hlg_decode.cu:115:int main(int, char**) Device: cpu
2021-11-18 14:48:20.89 [I] k2/torch/bin/hlg_decode.cu:124:int main(int, char**) Load wave files
2021-11-18 14:48:20.97 [I] k2/torch/bin/hlg_decode.cu:131:int main(int, char**) Build Fbank computer
2021-11-18 14:48:20.98 [I] k2/torch/bin/hlg_decode.cu:142:int main(int, char**) Compute features
2021-11-18 14:48:20.115 [I] k2/torch/bin/hlg_decode.cu:150:int main(int, char**) Load neural network model
2021-11-18 14:48:20.693 [I] k2/torch/bin/hlg_decode.cu:165:int main(int, char**) Compute nnet_output
2021-11-18 14:48:23.182 [I] k2/torch/bin/hlg_decode.cu:180:int main(int, char**) Load icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt
2021-11-18 14:48:33.489 [I] k2/torch/bin/hlg_decode.cu:185:int main(int, char**) Decoding
2021-11-18 14:48:45.217 [I] k2/torch/bin/hlg_decode.cu:216:int main(int, char**)
Decoding result:
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav
甚至 出现 交易 几乎 停止 的 情况
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav
一二 线 城市 虽然 也 处于 调整 中
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
但 因为 聚集 了 过多 公共 资源
There is a Colab notebook showing you how to run a torch scripted model in C++.
Please see |aishell asr conformer ctc torch script colab notebook|
.. |aishell asr conformer ctc torch script colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1Vh7RER7saTW01DtNbvr7CY7ovNZgmfWz?usp=sharing
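
You can also sanity-check the exported ``cpu_jit.pt`` from Python before
moving on to C++. The following is only a minimal sketch: it assumes the
80-dim fbank input and the ``forward()`` interface of
``conformer_ctc/transformer.py`` shown elsewhere in this commit.

.. code-block:: python

  import torch

  model = torch.jit.load("icefall_asr_aishell_conformer_ctc/exp/cpu_jit.pt")
  model.eval()

  # Inspect the TorchScript code of the exported forward().
  print(model.code)

  # Dummy forward pass with fbank-like features of shape (N, T, 80).
  # The supervisions argument of forward() is optional, so it is omitted here.
  with torch.no_grad():
      nnet_output, _, _ = model(torch.randn(1, 100, 80))
  print(nnet_output.shape)  # roughly (N, T // subsampling_factor, num_classes)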

View File

@@ -38,14 +38,13 @@ from icefall.decode import (
one_best_decoding,
rescore_with_attention_decoder,
)
+ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
- get_env_info,
get_texts,
setup_logger,
store_transcripts,
- str2bool,
write_error_stats,
)

@@ -113,17 +112,6 @@ def get_parser():
""",
)
- parser.add_argument(
-     "--export",
-     type=str2bool,
-     default=False,
-     help="""When enabled, the averaged model is saved to
-     conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
-     pretrained.pt contains a dict {"model": model.state_dict()},
-     which can be loaded by `icefall.checkpoint.load_checkpoint()`.
-     """,
- )
parser.add_argument(
"--exp-dir",
type=str,

@@ -544,13 +532,6 @@ def main():
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
- if params.export:
-     logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-     torch.save(
-         {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-     )
-     return

model.to(device)
model.eval()

num_param = sum([p.numel() for p in model.parameters()])

View File

@@ -0,0 +1,165 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=84,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=25,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="conformer_ctc/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="""It contains language related input files such as "lexicon.txt"
        """,
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=True,
        help="""True to save a model after applying torch.jit.script.
        """,
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "feature_dim": 80,
            "subsampling_factor": 4,
            "use_feat_batchnorm": True,
            "attention_dim": 512,
            "nhead": 4,
            "num_decoder_layers": 6,
        }
    )
    return params


def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)

    params = get_params()
    params.update(vars(args))

    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=False,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )
    model.to(device)

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if start >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames))

    model.to("cpu")
    model.eval()

    if params.jit:
        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = params.exp_dir / "cpu_jit.pt"
        model.save(str(filename))
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torch.jit.script")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
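
A minimal sketch of how the two artifacts written by this script are consumed
(paths are the script's defaults; ``--jit 1`` writes ``cpu_jit.pt``,
``--jit 0`` writes ``pretrained.pt``):

    import torch

    # TorchScript export: loadable without the Python definition of Conformer.
    scripted = torch.jit.load("conformer_ctc/exp/cpu_jit.pt")
    scripted.eval()

    # state_dict export: rebuild the Conformer exactly as in main() above,
    # then load the weights stored under the "model" key.
    ckpt = torch.load("conformer_ctc/exp/pretrained.pt", map_location="cpu")
    # model.load_state_dict(ckpt["model"])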

View File

@@ -0,0 +1,98 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

class LabelSmoothingLoss(torch.nn.Module):
    """
    Implement the LabelSmoothingLoss proposed in the following paper
    https://arxiv.org/pdf/1512.00567.pdf
    (Rethinking the Inception Architecture for Computer Vision)
    """

    def __init__(
        self,
        ignore_index: int = -1,
        label_smoothing: float = 0.1,
        reduction: str = "sum",
    ) -> None:
        """
        Args:
          ignore_index:
            ignored class id
          label_smoothing:
            smoothing rate (0.0 means the conventional cross entropy loss)
          reduction:
            It has the same meaning as the reduction in
            `torch.nn.CrossEntropyLoss`. It can be one of the following three
            values: (1) "none": No reduction will be applied. (2) "mean": the
            mean of the output is taken. (3) "sum": the output will be summed.
        """
        super().__init__()
        assert 0.0 <= label_smoothing < 1.0
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute loss between x and target.

        Args:
          x:
            prediction of dimension
            (batch_size, input_length, number_of_classes).
          target:
            target masked with self.ignore_index of
            dimension (batch_size, input_length).

        Returns:
          A scalar tensor containing the loss without normalization.
        """
        assert x.ndim == 3
        assert target.ndim == 2
        assert x.shape[:2] == target.shape
        num_classes = x.size(-1)
        x = x.reshape(-1, num_classes)
        # Now x is of shape (N*T, C)

        # We don't want to change target in-place below,
        # so we make a copy of it here
        target = target.clone().reshape(-1)

        ignored = target == self.ignore_index
        target[ignored] = 0

        true_dist = torch.nn.functional.one_hot(
            target, num_classes=num_classes
        ).to(x)

        true_dist = (
            true_dist * (1 - self.label_smoothing)
            + self.label_smoothing / num_classes
        )
        # Set the value of ignored indexes to 0
        true_dist[ignored] = 0

        loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
        if self.reduction == "sum":
            return loss.sum()
        elif self.reduction == "mean":
            return loss.sum() / (~ignored).sum()
        else:
            return loss.sum(dim=-1)
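
A small usage sketch for the loss above (shapes and values are illustrative
only):

    import torch

    loss_fn = LabelSmoothingLoss(
        ignore_index=-1, label_smoothing=0.1, reduction="sum"
    )
    logits = torch.randn(2, 5, 10)         # (batch_size, input_length, num_classes)
    target = torch.randint(0, 10, (2, 5))  # (batch_size, input_length)
    target[0, -1] = -1                     # padded position, ignored by the loss
    print(loss_fn(logits, target))         # scalar: sum over non-ignored positions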

View File

@@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_attention_decoder,
)
- from icefall.utils import AttributeDict, get_env_info, get_texts
+ from icefall.utils import AttributeDict, get_texts

def get_parser():

@@ -190,7 +190,6 @@ def get_params() -> AttributeDict:
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
- "env_info": get_env_info(),
}
)
return params

View File

@@ -38,12 +38,12 @@ from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
+ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
- get_env_info,
setup_logger,
str2bool,
)

View File

@@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
+ from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
@@ -83,8 +84,8 @@ class Transformer(nn.Module):
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
- # self.encoder_embed converts the input of shape [N, T, num_classes]
- # to the shape [N, T//subsampling_factor, d_model].
+ # self.encoder_embed converts the input of shape (N, T, num_classes)
+ # to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_classes -> d_model

@@ -152,7 +153,7 @@ class Transformer(nn.Module):
d_model, self.decoder_num_class
)
- self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
+ self.decoder_criterion = LabelSmoothingLoss()
else:
self.decoder_criterion = None

@@ -162,7 +163,7 @@ class Transformer(nn.Module):
"""
Args:
x:
- The input tensor. Its shape is [N, T, C].
+ The input tensor. Its shape is (N, T, C).
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa

@@ -171,17 +172,17 @@ class Transformer(nn.Module):
Returns:
Return a tuple containing 3 tensors:
- - CTC output for ctc decoding. Its shape is [N, T, C]
- - Encoder output with shape [T, N, C]. It can be used as key and
+ - CTC output for ctc decoding. Its shape is (N, T, C)
+ - Encoder output with shape (T, N, C). It can be used as key and
value for the decoder.
- Encoder output padding mask. It can be used as
- memory_key_padding_mask for the decoder. Its shape is [N, T].
+ memory_key_padding_mask for the decoder. Its shape is (N, T).
It is None if `supervision` is None.
"""
if self.use_feat_batchnorm:
- x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+ x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
- x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+ x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)

@@ -195,7 +196,7 @@ class Transformer(nn.Module):
Args:
x:
- The model input. Its shape is [N, T, C].
+ The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa

@@ -206,8 +207,8 @@ class Transformer(nn.Module):
padding mask for the decoder.
Returns:
Return a tuple with two tensors:
- - The encoder output, with shape [T, N, C]
- - encoder padding mask, with shape [N, T].
+ - The encoder output, with shape (T, N, C)
+ - encoder padding mask, with shape (N, T).
The mask is None if `supervisions` is None.
It is used as memory key padding mask in the decoder.
"""

@@ -225,17 +226,18 @@ class Transformer(nn.Module):
Args:
x:
The output tensor from the transformer encoder.
- Its shape is [T, N, C]
+ Its shape is (T, N, C)
Returns:
Return a tensor that can be used for CTC decoding.
- Its shape is [N, T, C]
+ Its shape is (N, T, C)
"""
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
return x

+ @torch.jit.export
def decoder_forward(
self,
memory: torch.Tensor,
@@ -247,7 +249,7 @@ class Transformer(nn.Module):
"""
Args:
memory:
- It's the output of the encoder with shape [T, N, C]
+ It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:

@@ -264,11 +266,15 @@ class Transformer(nn.Module):
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
- ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+ ys_in_pad = pad_sequence(
+     ys_in, batch_first=True, padding_value=float(eos_id)
+ )

ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
- ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+ ys_out_pad = pad_sequence(
+     ys_out, batch_first=True, padding_value=float(-1)
+ )

device = memory.device
ys_in_pad = ys_in_pad.to(device)

@@ -301,18 +307,19 @@ class Transformer(nn.Module):
return decoder_loss

+ @torch.jit.export
def decoder_nll(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
- token_ids: List[List[int]],
+ token_ids: List[torch.Tensor],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
- It's the output of the encoder with shape [T, N, C]
+ It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:

@@ -328,14 +335,23 @@ class Transformer(nn.Module):
"""
# The common part between this function and decoder_forward could be
# extracted as a separate function.
+ if isinstance(token_ids[0], torch.Tensor):
+     # This branch is executed by torchscript in C++.
+     # See https://github.com/k2-fsa/k2/pull/870
+     # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+     token_ids = [tolist(t) for t in token_ids]

ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
- ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+ ys_in_pad = pad_sequence(
+     ys_in, batch_first=True, padding_value=float(eos_id)
+ )

ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
- ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+ ys_out_pad = pad_sequence(
+     ys_out, batch_first=True, padding_value=float(-1)
+ )

device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@@ -649,24 +665,24 @@ class PositionalEncoding(nn.Module):
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
- self.pe = None
+ # not doing: self.pe = None because of errors thrown by torchscript
+ self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)

def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
- The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
- is [N, T, d_model]. If T > T1, then we change the shape of self.pe
- to [N, T, d_model]. Otherwise, nothing is done.
+ The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+ is (N, T, d_model). If T > T1, then we change the shape of self.pe
+ to (N, T, d_model). Otherwise, nothing is done.
Args:
x:
- It is a tensor of shape [N, T, C].
+ It is a tensor of shape (N, T, C).
Returns:
Return None.
"""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)

@@ -678,7 +694,7 @@ class PositionalEncoding(nn.Module):
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
- # Now pe is of shape [1, T, d_model], where T is x.size(1)
+ # Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)

def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -687,10 +703,10 @@ class PositionalEncoding(nn.Module):
Args:
x:
- Its shape is [N, T, C]
+ Its shape is (N, T, C)
Returns:
- Return a tensor of shape [N, T, C]
+ Return a tensor of shape (N, T, C)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]

@@ -784,73 +800,6 @@ class Noam(object):
setattr(self, key, value)
class LabelSmoothingLoss(nn.Module):
"""
Label-smoothing loss. KL-divergence between
q_{smoothed ground truth prob.}(w)
and p_{prob. computed by model}(w) is minimized.
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
Args:
size: the number of class
padding_idx: padding_idx: ignored class id
smoothing: smoothing rate (0.0 means the conventional CE)
normalize_length: normalize loss by sequence length if True
criterion: loss function to be smoothed
"""
def __init__(
self,
size: int,
padding_idx: int = -1,
smoothing: float = 0.1,
normalize_length: bool = False,
criterion: nn.Module = nn.KLDivLoss(reduction="none"),
) -> None:
"""Construct an LabelSmoothingLoss object."""
super(LabelSmoothingLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
assert 0.0 < smoothing <= 1.0
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.true_dist = None
self.normalize_length = normalize_length
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.padding_id of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.size(2) == self.size
# batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
true_dist = x.clone()
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
# denom = total if self.normalize_length else batch_size
denom = total if self.normalize_length else 1
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
def encoder_padding_mask(
max_len: int, supervisions: Optional[Supervisions] = None
) -> Optional[torch.Tensor]:

@@ -972,10 +921,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
Return a new list-of-list, where each sublist starts
with SOS ID.
"""
- ans = []
- for utt in token_ids:
-     ans.append([sos_id] + utt)
- return ans
+ return [[sos_id] + utt for utt in token_ids]

def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:

@@ -992,7 +938,9 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
Return a new list-of-list, where each sublist ends
with EOS ID.
"""
- ans = []
- for utt in token_ids:
-     ans.append(utt + [eos_id])
- return ans
+ return [utt + [eos_id] for utt in token_ids]

+ def tolist(t: torch.Tensor) -> List[int]:
+     """Used by jit"""
+     return torch.jit.annotate(List[int], t.tolist())
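
A tiny illustration (made-up values) of the three helpers above:

    import torch

    token_ids = [[23, 7, 9], [42]]
    assert add_sos(token_ids, sos_id=1) == [[1, 23, 7, 9], [1, 42]]
    assert add_eos(token_ids, eos_id=2) == [[23, 7, 9, 2], [42, 2]]
    # tolist() exists so that TorchScript sees a properly typed List[int]:
    assert tolist(torch.tensor([3, 5, 8])) == [3, 5, 8]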

View File

@@ -28,12 +28,12 @@ from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import one_best_decoding
+ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
encode_supervisions,
get_alignments,
- get_env_info,
save_alignments,
setup_logger,
)

View File

@@ -40,14 +40,13 @@ from icefall.decode import (
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
+ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
- get_env_info,
get_texts,
setup_logger,
store_transcripts,
- str2bool,
write_error_stats,
)

@@ -122,17 +121,6 @@ def get_parser():
""",
)
- parser.add_argument(
-     "--export",
-     type=str2bool,
-     default=False,
-     help="""When enabled, the averaged model is saved to
-     conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
-     pretrained.pt contains a dict {"model": model.state_dict()},
-     which can be loaded by `icefall.checkpoint.load_checkpoint()`.
-     """,
- )
parser.add_argument(
"--exp-dir",
type=str,

@@ -671,13 +659,6 @@ def main():
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
- if params.export:
-     logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-     torch.save(
-         {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-     )
-     return

model.to(device)
model.eval()

num_param = sum([p.numel() for p in model.parameters()])

View File

@@ -36,7 +36,7 @@ from icefall.decode import (
rescore_with_attention_decoder,
rescore_with_whole_lattice,
)
- from icefall.utils import AttributeDict, get_env_info, get_texts
+ from icefall.utils import AttributeDict, get_texts

def get_parser():

@@ -256,7 +256,6 @@ def main():
params.num_decoder_layers = 0
params.update(vars(args))
- params["env_info"] = get_env_info()

logging.info(f"{params}")

device = torch.device("cpu")

View File

@@ -41,12 +41,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
+ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
- get_env_info,
setup_logger,
str2bool,
)

View File

@@ -36,10 +36,10 @@ from icefall.decode import (
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
+ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
- get_env_info,
get_texts,
setup_logger,
store_transcripts,

View File

@@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_whole_lattice,
)
- from icefall.utils import AttributeDict, get_env_info, get_texts
+ from icefall.utils import AttributeDict, get_texts

def get_parser():

@@ -159,7 +159,6 @@ def main():
params = get_params()
params.update(vars(args))
- params["env_info"] = get_env_info()

logging.info(f"{params}")

device = torch.device("cpu")

View File

@@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
+ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
- get_env_info,
setup_logger,
str2bool,
)

View File

@@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_whole_lattice,
)
- from icefall.utils import AttributeDict, get_env_info, get_texts
+ from icefall.utils import AttributeDict, get_texts

def get_parser():

@@ -159,7 +159,6 @@ def main():
params = get_params()
params.update(vars(args))
- params["env_info"] = get_env_info()

logging.info(f"{params}")

device = torch.device("cpu")

View File

@@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
+ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
- get_env_info,
setup_logger,
str2bool,
)

View File

@@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_whole_lattice,
)
- from icefall.utils import AttributeDict, get_env_info, get_texts
+ from icefall.utils import AttributeDict, get_texts

def get_parser():

@@ -159,7 +159,6 @@ def main():
params = get_params()
params.update(vars(args))
- params["env_info"] = get_env_info()

logging.info(f"{params}")

device = torch.device("cpu")

View File

@@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
+ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
- get_env_info,
setup_logger,
str2bool,
)

View File

@@ -14,10 +14,10 @@ from model import Tdnn
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice, one_best_decoding
+ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
- get_env_info,
get_texts,
setup_logger,
store_transcripts,

View File

@@ -29,7 +29,7 @@ from model import Tdnn
from torch.nn.utils.rnn import pad_sequence

from icefall.decode import get_lattice, one_best_decoding
- from icefall.utils import AttributeDict, get_env_info, get_texts
+ from icefall.utils import AttributeDict, get_texts

def get_parser():

@@ -116,7 +116,6 @@ def main():
params = get_params()
params.update(vars(args))
- params["env_info"] = get_env_info()

logging.info(f"{params}")

device = torch.device("cpu")

View File

@@ -22,15 +22,10 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
+ from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
- from icefall.utils import (
-     AttributeDict,
-     MetricsTracker,
-     get_env_info,
-     setup_logger,
-     str2bool,
- )
+ from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

def get_parser():

icefall/env.py (new file, 106 lines added)
View File

@@ -0,0 +1,106 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Wei Kang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict
import k2
import k2.version
import lhotse
import torch

def get_git_sha1():
    git_commit = (
        subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    dirty_commit = (
        len(
            subprocess.run(
                ["git", "diff", "--shortstat"],
                check=True,
                stdout=subprocess.PIPE,
            )
            .stdout.decode()
            .rstrip("\n")
            .strip()
        )
        > 0
    )
    git_commit = (
        git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
    )
    return git_commit


def get_git_date():
    git_date = (
        subprocess.run(
            ["git", "log", "-1", "--format=%ad", "--date=local"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    return git_date


def get_git_branch_name():
    git_date = (
        subprocess.run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    return git_date


def get_env_info() -> Dict[str, Any]:
    """Get the environment information."""
    return {
        "k2-version": k2.version.__version__,
        "k2-build-type": k2.version.__build_type__,
        "k2-with-cuda": k2.with_cuda,
        "k2-git-sha1": k2.version.__git_sha1__,
        "k2-git-date": k2.version.__git_date__,
        "lhotse-version": lhotse.__version__,
        "torch-cuda-available": torch.cuda.is_available(),
        "torch-cuda-version": torch.version.cuda,
        "python-version": sys.version[:3],
        "icefall-git-branch": get_git_branch_name(),
        "icefall-git-sha1": get_git_sha1(),
        "icefall-git-date": get_git_date(),
        "icefall-path": str(Path(__file__).resolve().parent.parent),
        "k2-path": str(Path(k2.__file__).resolve()),
        "lhotse-path": str(Path(lhotse.__file__).resolve()),
    }
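
A short usage sketch: in the recipes this dict is typically stored in
``params`` so that the environment ends up in the training and decoding logs.

    from icefall.env import get_env_info
    from icefall.utils import AttributeDict

    params = AttributeDict({"env_info": get_env_info()})
    print(params["env_info"]["k2-version"])
    print(params["env_info"]["icefall-git-sha1"])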

View File

@@ -21,17 +21,15 @@ import collections
import logging
import os
import subprocess
- import sys
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
- from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
+ from typing import Dict, Iterable, List, TextIO, Tuple, Union

import k2
import k2.version
import kaldialign
- import lhotse
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

@@ -137,85 +135,6 @@ def setup_logger(
logging.getLogger("").addHandler(console)
def get_git_sha1():
git_commit = (
subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
dirty_commit = (
len(
subprocess.run(
["git", "diff", "--shortstat"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
> 0
)
git_commit = (
git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
)
return git_commit
def get_git_date():
git_date = (
subprocess.run(
["git", "log", "-1", "--format=%ad", "--date=local"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_date
def get_git_branch_name():
git_date = (
subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_date
def get_env_info() -> Dict[str, Any]:
"""Get the environment information."""
return {
"k2-version": k2.version.__version__,
"k2-build-type": k2.version.__build_type__,
"k2-with-cuda": k2.with_cuda,
"k2-git-sha1": k2.version.__git_sha1__,
"k2-git-date": k2.version.__git_date__,
"lhotse-version": lhotse.__version__,
"torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3],
"icefall-git-branch": get_git_branch_name(),
"icefall-git-sha1": get_git_sha1(),
"icefall-git-date": get_git_date(),
"icefall-path": str(Path(__file__).resolve().parent.parent),
"k2-path": str(Path(k2.__file__).resolve()),
"lhotse-path": str(Path(lhotse.__file__).resolve()),
}
class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:

View File

@@ -20,12 +20,8 @@ import k2
import pytest
import torch

- from icefall.utils import (
-     AttributeDict,
-     encode_supervisions,
-     get_env_info,
-     get_texts,
- )
+ from icefall.env import get_env_info
+ from icefall.utils import AttributeDict, encode_supervisions, get_texts

@pytest.fixture