From 4151cca147082586dbae9d4899bed12ba2fa9e1d Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Fri, 19 Nov 2021 16:37:05 +0800 Subject: [PATCH] Add torch script support for Aishell and update documents (#124) * Add aishell recipe * Remove unnecessary code and update docs * adapt to k2 v1.7, add docs and results * Update conformer ctc model * Update docs, pretrained.py & results * Fix code style * Fix code style * Fix code style * Minor fix * Minor fix * Fix pretrained.py * Update pretrained model & corresponding docs * Export torch script model for Aishell * Add C++ deployment docs * Minor fixes * Fix unit test * Update Readme --- README.md | 31 +++- docs/source/recipes/aishell/conformer_ctc.rst | 124 ++++++++++++- egs/aishell/ASR/conformer_ctc/decode.py | 21 +-- egs/aishell/ASR/conformer_ctc/export.py | 165 ++++++++++++++++++ .../ASR/conformer_ctc/label_smoothing.py | 98 +++++++++++ egs/aishell/ASR/conformer_ctc/pretrained.py | 3 +- egs/aishell/ASR/conformer_ctc/train.py | 2 +- egs/aishell/ASR/conformer_ctc/transformer.py | 160 ++++++----------- egs/librispeech/ASR/conformer_ctc/ali.py | 2 +- egs/librispeech/ASR/conformer_ctc/decode.py | 21 +-- .../ASR/conformer_ctc/pretrained.py | 3 +- egs/librispeech/ASR/conformer_ctc/train.py | 2 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 2 +- .../ASR/tdnn_lstm_ctc/pretrained.py | 3 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 2 +- egs/timit/ASR/tdnn_ligru_ctc/pretrained.py | 3 +- egs/timit/ASR/tdnn_ligru_ctc/train.py | 2 +- egs/timit/ASR/tdnn_lstm_ctc/pretrained.py | 3 +- egs/timit/ASR/tdnn_lstm_ctc/train.py | 2 +- egs/yesno/ASR/tdnn/decode.py | 2 +- egs/yesno/ASR/tdnn/pretrained.py | 3 +- egs/yesno/ASR/tdnn/train.py | 9 +- icefall/env.py | 106 +++++++++++ icefall/utils.py | 83 +-------- test/test_utils.py | 8 +- 25 files changed, 597 insertions(+), 263 deletions(-) create mode 100644 egs/aishell/ASR/conformer_ctc/export.py create mode 100644 egs/aishell/ASR/conformer_ctc/label_smoothing.py create mode 100644 
icefall/env.py diff --git a/README.md b/README.md index 140d07645..707ed09d0 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,11 @@ for installation. Please refer to for more information. -We provide three recipes at present: +We provide four recipes at present: - [yesno][yesno] - [LibriSpeech][librispeech] + - [Aishell][aishell] - [TIMIT][timit] ### yesno @@ -57,6 +58,31 @@ The WER for this model is: We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) +### Aishell + +We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc] +and [TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc]. + +#### Conformer CTC Model + +The best CER we currently have is: + +| | test | +|-----|------| +| CER | 4.26 | + + +We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing) + +#### TDNN LSTM CTC Model + +The CER for this model is: + +| | test | +|-----|-------| +| CER | 10.16 | + +We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z?usp=sharing) ### TIMIT @@ -99,9 +125,12 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc [LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc +[Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc +[Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc [TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc [TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc [yesno]: egs/yesno/ASR [librispeech]: egs/librispeech/ASR 
+[aishell]: egs/aishell/ASR [timit]: egs/timit/ASR [k2]: https://github.com/k2-fsa/k2 diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/aishell/conformer_ctc.rst index 20967780a..c7fd91e99 100644 --- a/docs/source/recipes/aishell/conformer_ctc.rst +++ b/docs/source/recipes/aishell/conformer_ctc.rst @@ -18,7 +18,7 @@ In this tutorial, you will learn: - (1) How to prepare data for training and decoding - (2) How to start the training, either with a single GPU or multiple GPUs - - (3) How to do decoding after training, with 1best and attention decoder rescoring + - (3) How to do decoding after training, with ctc-decoding, 1best and attention decoder rescoring - (4) How to use a pre-trained model, provided by us Data preparation @@ -623,3 +623,125 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained **Congratulations!** You have finished the aishell ASR recipe with conformer CTC models in ``icefall``. + + +If you want to deploy your trained model in C++, please read the following section. + +Deployment with C++ +------------------- + +This section describes how to deploy the pre-trained model in C++, without +Python dependencies. + +.. HINT:: + + At present, it does NOT support streaming decoding. + +First, let us compile k2 from source: + +.. code-block:: bash + + $ cd $HOME + $ git clone https://github.com/k2-fsa/k2 + $ cd k2 + $ git checkout v2.0-pre + +.. CAUTION:: + + You have to switch to the branch ``v2.0-pre``! + +.. code-block:: bash + + $ mkdir build-release + $ cd build-release + $ cmake -DCMAKE_BUILD_TYPE=Release .. + $ make -j hlg_decode + + # You will find the binary in `./bin`, i.e., ./bin/hlg_decode + +Now you are ready to go! + +Assume you have run: + + .. code-block:: bash + + $ cd k2/build-release + $ ln -s /path/to/icefall_asr_aishell_conformer_ctc ./ + +To view the usage of ``./bin/hlg_decode``, run: + +..
code-block:: + + $ ./bin/hlg_decode + +It will show you the following message: + +.. code-block:: bash + + Please provide --nn_model + + This file implements decoding with an HLG decoding graph. + + Usage: + ./bin/hlg_decode \ + --use_gpu true \ + --nn_model <path to torch scripted pt file> \ + --hlg <path to HLG.pt> \ + --word_table <path to words.txt> \ + <path to foo.wav> \ + <path to bar.wav> \ + <more waves if any> + + + To see all possible options, use + ./bin/hlg_decode --help + + Caution: + - Only sound files (*.wav) with single channel are supported. + - It assumes the model is conformer_ctc/transformer.py from icefall. + If you use a different model, you have to change the code + related to `model.forward` in this file. + + +HLG decoding +^^^^^^^^^^^^ + +.. code-block:: bash + + ./bin/hlg_decode \ + --use_gpu true \ + --nn_model icefall_asr_aishell_conformer_ctc/exp/cpu_jit.pt \ + --hlg icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \ + --word_table icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \ + icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \ + icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \ + icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav + +The output is: + +..
code-block:: + + 2021-11-18 14:48:20.89 [I] k2/torch/bin/hlg_decode.cu:115:int main(int, char**) Device: cpu + 2021-11-18 14:48:20.89 [I] k2/torch/bin/hlg_decode.cu:124:int main(int, char**) Load wave files + 2021-11-18 14:48:20.97 [I] k2/torch/bin/hlg_decode.cu:131:int main(int, char**) Build Fbank computer + 2021-11-18 14:48:20.98 [I] k2/torch/bin/hlg_decode.cu:142:int main(int, char**) Compute features + 2021-11-18 14:48:20.115 [I] k2/torch/bin/hlg_decode.cu:150:int main(int, char**) Load neural network model + 2021-11-18 14:48:20.693 [I] k2/torch/bin/hlg_decode.cu:165:int main(int, char**) Compute nnet_output + 2021-11-18 14:48:23.182 [I] k2/torch/bin/hlg_decode.cu:180:int main(int, char**) Load icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt + 2021-11-18 14:48:33.489 [I] k2/torch/bin/hlg_decode.cu:185:int main(int, char**) Decoding + 2021-11-18 14:48:45.217 [I] k2/torch/bin/hlg_decode.cu:216:int main(int, char**) + Decoding result: + + icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav + 甚至 出现 交易 几乎 停止 的 情况 + + icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav + 一二 线 城市 虽然 也 处于 调整 中 + + icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav + 但 因为 聚集 了 过多 公共 资源 + +There is a Colab notebook showing you how to run a torch scripted model in C++. +Please see |aishell asr conformer ctc torch script colab notebook| + +.. 
|aishell asr conformer ctc torch script colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg + :target: https://colab.research.google.com/drive/1Vh7RER7saTW01DtNbvr7CY7ovNZgmfWz?usp=sharing diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index 58ce39cca..dc593eeb9 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -38,14 +38,13 @@ from icefall.decode import ( one_best_decoding, rescore_with_attention_decoder, ) +from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - get_env_info, get_texts, setup_logger, store_transcripts, - str2bool, write_error_stats, ) @@ -113,17 +112,6 @@ def get_parser(): """, ) - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. 
- """, - ) - parser.add_argument( "--exp-dir", type=str, @@ -544,13 +532,6 @@ def main(): model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) - return - model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py new file mode 100644 index 000000000..42b8c29e7 --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/export.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=84, + help="It specifies the checkpoint to use for decoding." 
+ "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=25, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 4, + "num_decoder_layers": 6, + } + ) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, 
params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py new file mode 100644 index 000000000..cdc85ce9a --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1,98 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + + +class LabelSmoothingLoss(torch.nn.Module): + """ + Implement the LabelSmoothingLoss proposed in the following paper + https://arxiv.org/pdf/1512.00567.pdf + (Rethinking the Inception Architecture for Computer Vision) + + """ + + def __init__( + self, + ignore_index: int = -1, + label_smoothing: float = 0.1, + reduction: str = "sum", + ) -> None: + """ + Args: + ignore_index: + ignored class id + label_smoothing: + smoothing rate (0.0 means the conventional cross entropy loss) + reduction: + It has the same meaning as the reduction in + `torch.nn.CrossEntropyLoss`. It can be one of the following three + values: (1) "none": No reduction will be applied. (2) "mean": the + mean of the output is taken. (3) "sum": the output will be summed. + """ + super().__init__() + assert 0.0 <= label_smoothing < 1.0 + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.ignore_index of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. 
+ """ + assert x.ndim == 3 + assert target.ndim == 2 + assert x.shape[:2] == target.shape + num_classes = x.size(-1) + x = x.reshape(-1, num_classes) + # Now x is of shape (N*T, C) + + # We don't want to change target in-place below, + # so we make a copy of it here + target = target.clone().reshape(-1) + + ignored = target == self.ignore_index + target[ignored] = 0 + + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) + + true_dist = ( + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes + ) + # Set the value of ignored indexes to 0 + true_dist[ignored] = 0 + + loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) + if self.reduction == "sum": + return loss.sum() + elif self.reduction == "mean": + return loss.sum() / (~ignored).sum() + else: + return loss.sum(dim=-1) diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index 8657968ec..27776bc24 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -34,7 +34,7 @@ from icefall.decode import ( one_best_decoding, rescore_with_attention_decoder, ) -from icefall.utils import AttributeDict, get_env_info, get_texts +from icefall.utils import AttributeDict, get_texts def get_parser(): @@ -190,7 +190,6 @@ def get_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, - "env_info": get_env_info(), } ) return params diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index 94367ed4e..629d7a373 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -38,12 +38,12 @@ from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist +from icefall.env import 
get_env_info from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, encode_supervisions, - get_env_info, setup_logger, str2bool, ) diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py index 88b10b23d..f93914aaa 100644 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ b/egs/aishell/ASR/conformer_ctc/transformer.py @@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn +from label_smoothing import LabelSmoothingLoss from subsampling import Conv2dSubsampling, VggSubsampling from torch.nn.utils.rnn import pad_sequence @@ -83,8 +84,8 @@ class Transformer(nn.Module): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - # self.encoder_embed converts the input of shape [N, T, num_classes] - # to the shape [N, T//subsampling_factor, d_model]. + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_classes -> d_model @@ -152,7 +153,7 @@ class Transformer(nn.Module): d_model, self.decoder_num_class ) - self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) + self.decoder_criterion = LabelSmoothingLoss() else: self.decoder_criterion = None @@ -162,7 +163,7 @@ class Transformer(nn.Module): """ Args: x: - The input tensor. Its shape is [N, T, C]. + The input tensor. Its shape is (N, T, C). supervision: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -171,17 +172,17 @@ class Transformer(nn.Module): Returns: Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is [N, T, C] - - Encoder output with shape [T, N, C]. It can be used as key and + - CTC output for ctc decoding. 
Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and value for the decoder. - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is [N, T]. + memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) encoder_memory, memory_key_padding_mask = self.run_encoder( x, supervision ) @@ -195,7 +196,7 @@ class Transformer(nn.Module): Args: x: - The model input. Its shape is [N, T, C]. + The model input. Its shape is (N, T, C). supervisions: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -206,8 +207,8 @@ class Transformer(nn.Module): padding mask for the decoder. Returns: Return a tuple with two tensors: - - The encoder output, with shape [T, N, C] - - encoder padding mask, with shape [N, T]. + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). The mask is None if `supervisions` is None. It is used as memory key padding mask in the decoder. """ @@ -225,17 +226,18 @@ class Transformer(nn.Module): Args: x: The output tensor from the transformer encoder. - Its shape is [T, N, C] + Its shape is (T, N, C) Returns: Return a tensor that can be used for CTC decoding. 
- Its shape is [N, T, C] + Its shape is (N, T, C) """ x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) return x + @torch.jit.export def decoder_forward( self, memory: torch.Tensor, @@ -247,7 +249,7 @@ class Transformer(nn.Module): """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -264,11 +266,15 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) @@ -301,18 +307,19 @@ class Transformer(nn.Module): return decoder_loss + @torch.jit.export def decoder_nll( self, memory: torch.Tensor, memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], + token_ids: List[torch.Tensor], sos_id: int, eos_id: int, ) -> torch.Tensor: """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -328,14 +335,23 @@ class Transformer(nn.Module): """ # The common part between this function and decoder_forward could be # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. 
+ # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) @@ -649,25 +665,25 @@ class PositionalEncoding(nn.Module): self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout) - self.pe = None + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) def extend_pe(self, x: torch.Tensor) -> None: """Extend the time t in the positional encoding if required. - The shape of `self.pe` is [1, T1, d_model]. The shape of the input x - is [N, T, d_model]. If T > T1, then we change the shape of self.pe - to [N, T, d_model]. Otherwise, nothing is done. + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (1, T, d_model). Otherwise, nothing is done. Args: x: - It is a tensor of shape [N, T, C]. + It is a tensor of shape (N, T, C). Returns: Return None.
""" if self.pe is not None: if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) + self.pe = self.pe.to(dtype=x.dtype, device=x.device) return pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) @@ -678,7 +694,7 @@ class PositionalEncoding(nn.Module): pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) - # Now pe is of shape [1, T, d_model], where T is x.size(1) + # Now pe is of shape (1, T, d_model), where T is x.size(1) self.pe = pe.to(device=x.device, dtype=x.dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -687,10 +703,10 @@ class PositionalEncoding(nn.Module): Args: x: - Its shape is [N, T, C] + Its shape is (N, T, C) Returns: - Return a tensor of shape [N, T, C] + Return a tensor of shape (N, T, C) """ self.extend_pe(x) x = x * self.xscale + self.pe[:, : x.size(1), :] @@ -784,73 +800,6 @@ class Noam(object): setattr(self, key, value) -class LabelSmoothingLoss(nn.Module): - """ - Label-smoothing loss. KL-divergence between - q_{smoothed ground truth prob.}(w) - and p_{prob. computed by model}(w) is minimized. 
- Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa - - Args: - size: the number of class - padding_idx: padding_idx: ignored class id - smoothing: smoothing rate (0.0 means the conventional CE) - normalize_length: normalize loss by sequence length if True - criterion: loss function to be smoothed - """ - - def __init__( - self, - size: int, - padding_idx: int = -1, - smoothing: float = 0.1, - normalize_length: bool = False, - criterion: nn.Module = nn.KLDivLoss(reduction="none"), - ) -> None: - """Construct an LabelSmoothingLoss object.""" - super(LabelSmoothingLoss, self).__init__() - self.criterion = criterion - self.padding_idx = padding_idx - assert 0.0 < smoothing <= 1.0 - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.size = size - self.true_dist = None - self.normalize_length = normalize_length - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.padding_id of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. 
- """ - assert x.size(2) == self.size - # batch_size = x.size(0) - x = x.view(-1, self.size) - target = target.view(-1) - with torch.no_grad(): - true_dist = x.clone() - true_dist.fill_(self.smoothing / (self.size - 1)) - ignore = target == self.padding_idx # (B,) - total = len(target) - ignore.sum().item() - target = target.masked_fill(ignore, 0) # avoid -1 index - true_dist.scatter_(1, target.unsqueeze(1), self.confidence) - kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) - # denom = total if self.normalize_length else batch_size - denom = total if self.normalize_length else 1 - return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom - - def encoder_padding_mask( max_len: int, supervisions: Optional[Supervisions] = None ) -> Optional[torch.Tensor]: @@ -972,10 +921,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: Return a new list-of-list, where each sublist starts with SOS ID. """ - ans = [] - for utt in token_ids: - ans.append([sos_id] + utt) - return ans + return [[sos_id] + utt for utt in token_ids] def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: @@ -992,7 +938,9 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: Return a new list-of-list, where each sublist ends with EOS ID. 
""" - ans = [] - for utt in token_ids: - ans.append(utt + [eos_id]) - return ans + return [utt + [eos_id] for utt in token_ids] + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index ad72a88e7..2b2967506 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -28,12 +28,12 @@ from conformer import Conformer from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.decode import one_best_decoding +from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, encode_supervisions, get_alignments, - get_env_info, save_alignments, setup_logger, ) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index f5ffe026e..63aed9358 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -40,14 +40,13 @@ from icefall.decode import ( rescore_with_n_best_list, rescore_with_whole_lattice, ) +from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - get_env_info, get_texts, setup_logger, store_transcripts, - str2bool, write_error_stats, ) @@ -122,17 +121,6 @@ def get_parser(): """, ) - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. 
- """, - ) - parser.add_argument( "--exp-dir", type=str, @@ -671,13 +659,6 @@ def main(): model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) - return - model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 95589b82b..28724e1eb 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -36,7 +36,7 @@ from icefall.decode import ( rescore_with_attention_decoder, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_env_info, get_texts +from icefall.utils import AttributeDict, get_texts def get_parser(): @@ -256,7 +256,6 @@ def main(): params.num_decoder_layers = 0 params.update(vars(args)) - params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 46ea5c60c..c6063fade 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -41,12 +41,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, encode_supervisions, - get_env_info, setup_logger, str2bool, ) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 72ee2ff0b..636cb9388 100755 --- 
a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -36,10 +36,10 @@ from icefall.decode import ( rescore_with_n_best_list, rescore_with_whole_lattice, ) +from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - get_env_info, get_texts, setup_logger, store_transcripts, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index e0d6a7a60..2baeb6bba 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -34,7 +34,7 @@ from icefall.decode import ( one_best_decoding, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_env_info, get_texts +from icefall.utils import AttributeDict, get_texts def get_parser(): @@ -159,7 +159,6 @@ def main(): params = get_params() params.update(vars(args)) - params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 7904b0e61..99fe170d2 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, encode_supervisions, - get_env_info, setup_logger, str2bool, ) diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py index 024051709..7da285944 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py @@ 
-34,7 +34,7 @@ from icefall.decode import ( one_best_decoding, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_env_info, get_texts +from icefall.utils import AttributeDict, get_texts def get_parser(): @@ -159,7 +159,6 @@ def main(): params = get_params() params.update(vars(args)) - params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py index 53b49dec2..9ac4743b4 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/train.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py @@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, encode_supervisions, - get_env_info, setup_logger, str2bool, ) diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py index 95fd84f24..5f478da1c 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py @@ -34,7 +34,7 @@ from icefall.decode import ( one_best_decoding, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_env_info, get_texts +from icefall.utils import AttributeDict, get_texts def get_parser(): @@ -159,7 +159,6 @@ def main(): params = get_params() params.update(vars(args)) - params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py index a5c8eb26c..2a6ff4787 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/train.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py @@ -40,13 +40,13 @@ from 
torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, encode_supervisions, - get_env_info, setup_logger, str2bool, ) diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 9df019bf5..a6a57a2fc 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -14,10 +14,10 @@ from model import Tdnn from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.decode import get_lattice, one_best_decoding +from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - get_env_info, get_texts, setup_logger, store_transcripts, diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index 75758b984..14220be19 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -29,7 +29,7 @@ from model import Tdnn from torch.nn.utils.rnn import pad_sequence from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_env_info, get_texts +from icefall.utils import AttributeDict, get_texts def get_parser(): @@ -116,7 +116,6 @@ def main(): params = get_params() params.update(vars(args)) - params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index e24061fa1..30d83666a 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -22,15 +22,10 @@ from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl 
from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_env_info, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool def get_parser(): diff --git a/icefall/env.py b/icefall/env.py new file mode 100644 index 000000000..fd56ad8c2 --- /dev/null +++ b/icefall/env.py @@ -0,0 +1,106 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Wei Kang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import subprocess
+import sys
+from pathlib import Path
+from typing import Any, Dict
+
+import k2
+import k2.version
+import lhotse
+import torch
+
+
+def get_git_sha1():
+    git_commit = (
+        subprocess.run(
+            ["git", "rev-parse", "--short", "HEAD"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    dirty_commit = (
+        len(
+            subprocess.run(
+                ["git", "diff", "--shortstat"],
+                check=True,
+                stdout=subprocess.PIPE,
+            )
+            .stdout.decode()
+            .rstrip("\n")
+            .strip()
+        )
+        > 0
+    )
+    git_commit = (
+        git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+    )
+    return git_commit
+
+
+def get_git_date():
+    git_date = (
+        subprocess.run(
+            ["git", "log", "-1", "--format=%ad", "--date=local"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    return git_date
+
+
+def get_git_branch_name():
+    git_branch = (
+        subprocess.run(
+            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        .stdout.decode()
+        .rstrip("\n")
+        .strip()
+    )
+    return git_branch
+
+
+def get_env_info() -> Dict[str, Any]:
+    """Get the environment information."""
+    return {
+        "k2-version": k2.version.__version__,
+        "k2-build-type": k2.version.__build_type__,
+        "k2-with-cuda": k2.with_cuda,
+        "k2-git-sha1": k2.version.__git_sha1__,
+        "k2-git-date": k2.version.__git_date__,
+        "lhotse-version": lhotse.__version__,
+        "torch-cuda-available": torch.cuda.is_available(),
+        "torch-cuda-version": torch.version.cuda,
+        "python-version": "%d.%d" % sys.version_info[:2],  # "3.10", not "3.1"
+        "icefall-git-branch": get_git_branch_name(),
+        "icefall-git-sha1": get_git_sha1(),
+        "icefall-git-date": get_git_date(),
+        "icefall-path": str(Path(__file__).resolve().parent.parent),
+        "k2-path": str(Path(k2.__file__).resolve()),
+        "lhotse-path": str(Path(lhotse.__file__).resolve()),
+    }
diff --git a/icefall/utils.py b/icefall/utils.py
index 1d4aabd72..ba9436fa4 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -21,17 +21,15
@@ import collections import logging import os import subprocess -import sys from collections import defaultdict from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union +from typing import Dict, Iterable, List, TextIO, Tuple, Union import k2 import k2.version import kaldialign -import lhotse import torch import torch.distributed as dist from torch.utils.tensorboard import SummaryWriter @@ -137,85 +135,6 @@ def setup_logger( logging.getLogger("").addHandler(console) -def get_git_sha1(): - git_commit = ( - subprocess.run( - ["git", "rev-parse", "--short", "HEAD"], - check=True, - stdout=subprocess.PIPE, - ) - .stdout.decode() - .rstrip("\n") - .strip() - ) - dirty_commit = ( - len( - subprocess.run( - ["git", "diff", "--shortstat"], - check=True, - stdout=subprocess.PIPE, - ) - .stdout.decode() - .rstrip("\n") - .strip() - ) - > 0 - ) - git_commit = ( - git_commit + "-dirty" if dirty_commit else git_commit + "-clean" - ) - return git_commit - - -def get_git_date(): - git_date = ( - subprocess.run( - ["git", "log", "-1", "--format=%ad", "--date=local"], - check=True, - stdout=subprocess.PIPE, - ) - .stdout.decode() - .rstrip("\n") - .strip() - ) - return git_date - - -def get_git_branch_name(): - git_date = ( - subprocess.run( - ["git", "rev-parse", "--abbrev-ref", "HEAD"], - check=True, - stdout=subprocess.PIPE, - ) - .stdout.decode() - .rstrip("\n") - .strip() - ) - return git_date - - -def get_env_info() -> Dict[str, Any]: - """Get the environment information.""" - return { - "k2-version": k2.version.__version__, - "k2-build-type": k2.version.__build_type__, - "k2-with-cuda": k2.with_cuda, - "k2-git-sha1": k2.version.__git_sha1__, - "k2-git-date": k2.version.__git_date__, - "lhotse-version": lhotse.__version__, - "torch-cuda-available": torch.cuda.is_available(), - "torch-cuda-version": torch.version.cuda, - "python-version": sys.version[:3], - 
"icefall-git-branch": get_git_branch_name(), - "icefall-git-sha1": get_git_sha1(), - "icefall-git-date": get_git_date(), - "icefall-path": str(Path(__file__).resolve().parent.parent), - "k2-path": str(Path(k2.__file__).resolve()), - "lhotse-path": str(Path(lhotse.__file__).resolve()), - } - - class AttributeDict(dict): def __getattr__(self, key): if key in self: diff --git a/test/test_utils.py b/test/test_utils.py index b8c742c5a..01916bc59 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,12 +20,8 @@ import k2 import pytest import torch -from icefall.utils import ( - AttributeDict, - encode_supervisions, - get_env_info, - get_texts, -) +from icefall.env import get_env_info +from icefall.utils import AttributeDict, encode_supervisions, get_texts @pytest.fixture