mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Add torch script support for Aishell and update documents (#124)
* Add aishell recipe * Remove unnecessary code and update docs * adapt to k2 v1.7, add docs and results * Update conformer ctc model * Update docs, pretrained.py & results * Fix code style * Fix code style * Fix code style * Minor fix * Minor fix * Fix pretrained.py * Update pretrained model & corresponding docs * Export torch script model for Aishell * Add C++ deployment docs * Minor fixes * Fix unit test * Update Readme
This commit is contained in:
parent
30c43b7f69
commit
4151cca147
31
README.md
31
README.md
@ -12,10 +12,11 @@ for installation.
|
|||||||
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/index.html>
|
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/index.html>
|
||||||
for more information.
|
for more information.
|
||||||
|
|
||||||
We provide three recipes at present:
|
We provide four recipes at present:
|
||||||
|
|
||||||
- [yesno][yesno]
|
- [yesno][yesno]
|
||||||
- [LibriSpeech][librispeech]
|
- [LibriSpeech][librispeech]
|
||||||
|
- [Aishell][aishell]
|
||||||
- [TIMIT][timit]
|
- [TIMIT][timit]
|
||||||
|
|
||||||
### yesno
|
### yesno
|
||||||
@ -57,6 +58,31 @@ The WER for this model is:
|
|||||||
|
|
||||||
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
|
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
|
||||||
|
|
||||||
|
### Aishell
|
||||||
|
|
||||||
|
We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]
|
||||||
|
and [TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc].
|
||||||
|
|
||||||
|
#### Conformer CTC Model
|
||||||
|
|
||||||
|
The best CER we currently have is:
|
||||||
|
|
||||||
|
| | test |
|
||||||
|
|-----|------|
|
||||||
|
| CER | 4.26 |
|
||||||
|
|
||||||
|
|
||||||
|
We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
|
||||||
|
|
||||||
|
#### TDNN LSTM CTC Model
|
||||||
|
|
||||||
|
The CER for this model is:
|
||||||
|
|
||||||
|
| | test |
|
||||||
|
|-----|-------|
|
||||||
|
| CER | 10.16 |
|
||||||
|
|
||||||
|
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1qULaGvXq7PCu_P61oubfz9b53JzY4H3z?usp=sharing)
|
||||||
|
|
||||||
### TIMIT
|
### TIMIT
|
||||||
|
|
||||||
@ -99,9 +125,12 @@ Please see: [ How to prepare data for training and decoding
|
- (1) How to prepare data for training and decoding
|
||||||
- (2) How to start the training, either with a single GPU or multiple GPUs
|
- (2) How to start the training, either with a single GPU or multiple GPUs
|
||||||
- (3) How to do decoding after training, with 1best and attention decoder rescoring
|
- (3) How to do decoding after training, with ctc-decoding, 1best and attention decoder rescoring
|
||||||
- (4) How to use a pre-trained model, provided by us
|
- (4) How to use a pre-trained model, provided by us
|
||||||
|
|
||||||
Data preparation
|
Data preparation
|
||||||
@ -623,3 +623,125 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained
|
|||||||
|
|
||||||
**Congratulations!** You have finished the aishell ASR recipe with
|
**Congratulations!** You have finished the aishell ASR recipe with
|
||||||
conformer CTC models in ``icefall``.
|
conformer CTC models in ``icefall``.
|
||||||
|
|
||||||
|
|
||||||
|
If you want to deploy your trained model in C++, please read the following section.
|
||||||
|
|
||||||
|
Deployment with C++
|
||||||
|
-------------------
|
||||||
|
|
||||||
|
This section describes how to deploy the pre-trained model in C++, without
|
||||||
|
Python dependencies.
|
||||||
|
|
||||||
|
.. HINT::
|
||||||
|
|
||||||
|
At present, it does NOT support streaming decoding.
|
||||||
|
|
||||||
|
First, let us compile k2 from source:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ cd $HOME
|
||||||
|
$ git clone https://github.com/k2-fsa/k2
|
||||||
|
$ cd k2
|
||||||
|
$ git checkout v2.0-pre
|
||||||
|
|
||||||
|
.. CAUTION::
|
||||||
|
|
||||||
|
You have to switch to the branch ``v2.0-pre``!
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ mkdir build-release
|
||||||
|
$ cd build-release
|
||||||
|
$ cmake -DCMAKE_BUILD_TYPE=Release ..
|
||||||
|
$ make -j hlg_decode
|
||||||
|
|
||||||
|
# You will find four binaries in `./bin`, i.e. ./bin/hlg_decode,
|
||||||
|
|
||||||
|
Now you are ready to go!
|
||||||
|
|
||||||
|
Assume you have run:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
$ cd k2/build-release
|
||||||
|
$ ln -s /path/to/icefall-asr-aishell-conformer-ctc ./
|
||||||
|
|
||||||
|
To view the usage of ``./bin/hlg_decode``, run:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
$ ./bin/hlg_decode
|
||||||
|
|
||||||
|
It will show you the following message:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
Please provide --nn_model
|
||||||
|
|
||||||
|
This file implements decoding with an HLG decoding graph.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
./bin/hlg_decode \
|
||||||
|
--use_gpu true \
|
||||||
|
--nn_model <path to torch scripted pt file> \
|
||||||
|
--hlg <path to HLG.pt> \
|
||||||
|
--word_table <path to words.txt> \
|
||||||
|
<path to foo.wav> \
|
||||||
|
<path to bar.wav> \
|
||||||
|
<more waves if any>
|
||||||
|
|
||||||
|
To see all possible options, use
|
||||||
|
./bin/hlg_decode --help
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
- Only sound files (*.wav) with single channel are supported.
|
||||||
|
- It assumes the model is conformer_ctc/transformer.py from icefall.
|
||||||
|
If you use a different model, you have to change the code
|
||||||
|
related to `model.forward` in this file.
|
||||||
|
|
||||||
|
|
||||||
|
HLG decoding
|
||||||
|
^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
./bin/hlg_decode \
|
||||||
|
--use_gpu true \
|
||||||
|
--nn_model icefall_asr_aishell_conformer_ctc/exp/cpu_jit.pt \
|
||||||
|
--hlg icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \
|
||||||
|
--word_table icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \
|
||||||
|
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \
|
||||||
|
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \
|
||||||
|
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
|
||||||
|
|
||||||
|
The output is:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
2021-11-18 14:48:20.89 [I] k2/torch/bin/hlg_decode.cu:115:int main(int, char**) Device: cpu
|
||||||
|
2021-11-18 14:48:20.89 [I] k2/torch/bin/hlg_decode.cu:124:int main(int, char**) Load wave files
|
||||||
|
2021-11-18 14:48:20.97 [I] k2/torch/bin/hlg_decode.cu:131:int main(int, char**) Build Fbank computer
|
||||||
|
2021-11-18 14:48:20.98 [I] k2/torch/bin/hlg_decode.cu:142:int main(int, char**) Compute features
|
||||||
|
2021-11-18 14:48:20.115 [I] k2/torch/bin/hlg_decode.cu:150:int main(int, char**) Load neural network model
|
||||||
|
2021-11-18 14:48:20.693 [I] k2/torch/bin/hlg_decode.cu:165:int main(int, char**) Compute nnet_output
|
||||||
|
2021-11-18 14:48:23.182 [I] k2/torch/bin/hlg_decode.cu:180:int main(int, char**) Load icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt
|
||||||
|
2021-11-18 14:48:33.489 [I] k2/torch/bin/hlg_decode.cu:185:int main(int, char**) Decoding
|
||||||
|
2021-11-18 14:48:45.217 [I] k2/torch/bin/hlg_decode.cu:216:int main(int, char**)
|
||||||
|
Decoding result:
|
||||||
|
|
||||||
|
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav
|
||||||
|
甚至 出现 交易 几乎 停止 的 情况
|
||||||
|
|
||||||
|
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav
|
||||||
|
一二 线 城市 虽然 也 处于 调整 中
|
||||||
|
|
||||||
|
icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav
|
||||||
|
但 因为 聚集 了 过多 公共 资源
|
||||||
|
|
||||||
|
There is a Colab notebook showing you how to run a torch scripted model in C++.
|
||||||
|
Please see |aishell asr conformer ctc torch script colab notebook|
|
||||||
|
|
||||||
|
.. |aishell asr conformer ctc torch script colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
|
||||||
|
:target: https://colab.research.google.com/drive/1Vh7RER7saTW01DtNbvr7CY7ovNZgmfWz?usp=sharing
|
||||||
|
@ -38,14 +38,13 @@ from icefall.decode import (
|
|||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_attention_decoder,
|
rescore_with_attention_decoder,
|
||||||
)
|
)
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
get_env_info,
|
|
||||||
get_texts,
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
str2bool,
|
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -113,17 +112,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--export",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""When enabled, the averaged model is saved to
|
|
||||||
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
|
|
||||||
pretrained.pt contains a dict {"model": model.state_dict()},
|
|
||||||
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -544,13 +532,6 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
if params.export:
|
|
||||||
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
|
||||||
torch.save(
|
|
||||||
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
165
egs/aishell/ASR/conformer_ctc/export.py
Normal file
165
egs/aishell/ASR/conformer_ctc/export.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# This script converts several saved checkpoints
|
||||||
|
# to a single one using model averaging.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from conformer import Conformer
|
||||||
|
|
||||||
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import AttributeDict, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=84,
|
||||||
|
help="It specifies the checkpoint to use for decoding."
|
||||||
|
"Note: Epoch counts from 0.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch'. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="conformer_ctc/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="""It contains language related input files such as "lexicon.txt"
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="""True to save a model after applying torch.jit.script.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"feature_dim": 80,
|
||||||
|
"subsampling_factor": 4,
|
||||||
|
"use_feat_batchnorm": True,
|
||||||
|
"attention_dim": 512,
|
||||||
|
"nhead": 4,
|
||||||
|
"num_decoder_layers": 6,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
args.lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
num_classes = max_token_id + 1 # +1 for the blank
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
model = Conformer(
|
||||||
|
num_features=params.feature_dim,
|
||||||
|
nhead=params.nhead,
|
||||||
|
d_model=params.attention_dim,
|
||||||
|
num_classes=num_classes,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
|
vgg_frontend=False,
|
||||||
|
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if start >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.load_state_dict(average_checkpoints(filenames))
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if params.jit:
|
||||||
|
logging.info("Using torch.jit.script")
|
||||||
|
model = torch.jit.script(model)
|
||||||
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
model.save(str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
else:
|
||||||
|
logging.info("Not using torch.jit.script")
|
||||||
|
# Save it using a format so that it can be loaded
|
||||||
|
# by :func:`load_checkpoint`
|
||||||
|
filename = params.exp_dir / "pretrained.pt"
|
||||||
|
torch.save({"model": model.state_dict()}, str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
98
egs/aishell/ASR/conformer_ctc/label_smoothing.py
Normal file
98
egs/aishell/ASR/conformer_ctc/label_smoothing.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class LabelSmoothingLoss(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implement the LabelSmoothingLoss proposed in the following paper
|
||||||
|
https://arxiv.org/pdf/1512.00567.pdf
|
||||||
|
(Rethinking the Inception Architecture for Computer Vision)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ignore_index: int = -1,
|
||||||
|
label_smoothing: float = 0.1,
|
||||||
|
reduction: str = "sum",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
ignore_index:
|
||||||
|
ignored class id
|
||||||
|
label_smoothing:
|
||||||
|
smoothing rate (0.0 means the conventional cross entropy loss)
|
||||||
|
reduction:
|
||||||
|
It has the same meaning as the reduction in
|
||||||
|
`torch.nn.CrossEntropyLoss`. It can be one of the following three
|
||||||
|
values: (1) "none": No reduction will be applied. (2) "mean": the
|
||||||
|
mean of the output is taken. (3) "sum": the output will be summed.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert 0.0 <= label_smoothing < 1.0
|
||||||
|
self.ignore_index = ignore_index
|
||||||
|
self.label_smoothing = label_smoothing
|
||||||
|
self.reduction = reduction
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute loss between x and target.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
prediction of dimension
|
||||||
|
(batch_size, input_length, number_of_classes).
|
||||||
|
target:
|
||||||
|
target masked with self.ignore_index of
|
||||||
|
dimension (batch_size, input_length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A scalar tensor containing the loss without normalization.
|
||||||
|
"""
|
||||||
|
assert x.ndim == 3
|
||||||
|
assert target.ndim == 2
|
||||||
|
assert x.shape[:2] == target.shape
|
||||||
|
num_classes = x.size(-1)
|
||||||
|
x = x.reshape(-1, num_classes)
|
||||||
|
# Now x is of shape (N*T, C)
|
||||||
|
|
||||||
|
# We don't want to change target in-place below,
|
||||||
|
# so we make a copy of it here
|
||||||
|
target = target.clone().reshape(-1)
|
||||||
|
|
||||||
|
ignored = target == self.ignore_index
|
||||||
|
target[ignored] = 0
|
||||||
|
|
||||||
|
true_dist = torch.nn.functional.one_hot(
|
||||||
|
target, num_classes=num_classes
|
||||||
|
).to(x)
|
||||||
|
|
||||||
|
true_dist = (
|
||||||
|
true_dist * (1 - self.label_smoothing)
|
||||||
|
+ self.label_smoothing / num_classes
|
||||||
|
)
|
||||||
|
# Set the value of ignored indexes to 0
|
||||||
|
true_dist[ignored] = 0
|
||||||
|
|
||||||
|
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
|
||||||
|
if self.reduction == "sum":
|
||||||
|
return loss.sum()
|
||||||
|
elif self.reduction == "mean":
|
||||||
|
return loss.sum() / (~ignored).sum()
|
||||||
|
else:
|
||||||
|
return loss.sum(dim=-1)
|
@ -34,7 +34,7 @@ from icefall.decode import (
|
|||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_attention_decoder,
|
rescore_with_attention_decoder,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, get_env_info, get_texts
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -190,7 +190,6 @@ def get_params() -> AttributeDict:
|
|||||||
"min_active_states": 30,
|
"min_active_states": 30,
|
||||||
"max_active_states": 10000,
|
"max_active_states": 10000,
|
||||||
"use_double_scores": True,
|
"use_double_scores": True,
|
||||||
"env_info": get_env_info(),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return params
|
return params
|
||||||
|
@ -38,12 +38,12 @@ from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
|||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
MetricsTracker,
|
MetricsTracker,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
get_env_info,
|
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
|
@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from label_smoothing import LabelSmoothingLoss
|
||||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
@ -83,8 +84,8 @@ class Transformer(nn.Module):
|
|||||||
if subsampling_factor != 4:
|
if subsampling_factor != 4:
|
||||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||||
|
|
||||||
# self.encoder_embed converts the input of shape [N, T, num_classes]
|
# self.encoder_embed converts the input of shape (N, T, num_classes)
|
||||||
# to the shape [N, T//subsampling_factor, d_model].
|
# to the shape (N, T//subsampling_factor, d_model).
|
||||||
# That is, it does two things simultaneously:
|
# That is, it does two things simultaneously:
|
||||||
# (1) subsampling: T -> T//subsampling_factor
|
# (1) subsampling: T -> T//subsampling_factor
|
||||||
# (2) embedding: num_classes -> d_model
|
# (2) embedding: num_classes -> d_model
|
||||||
@ -152,7 +153,7 @@ class Transformer(nn.Module):
|
|||||||
d_model, self.decoder_num_class
|
d_model, self.decoder_num_class
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
|
self.decoder_criterion = LabelSmoothingLoss()
|
||||||
else:
|
else:
|
||||||
self.decoder_criterion = None
|
self.decoder_criterion = None
|
||||||
|
|
||||||
@ -162,7 +163,7 @@ class Transformer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
The input tensor. Its shape is [N, T, C].
|
The input tensor. Its shape is (N, T, C).
|
||||||
supervision:
|
supervision:
|
||||||
Supervision in lhotse format.
|
Supervision in lhotse format.
|
||||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||||
@ -171,17 +172,17 @@ class Transformer(nn.Module):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a tuple containing 3 tensors:
|
Return a tuple containing 3 tensors:
|
||||||
- CTC output for ctc decoding. Its shape is [N, T, C]
|
- CTC output for ctc decoding. Its shape is (N, T, C)
|
||||||
- Encoder output with shape [T, N, C]. It can be used as key and
|
- Encoder output with shape (T, N, C). It can be used as key and
|
||||||
value for the decoder.
|
value for the decoder.
|
||||||
- Encoder output padding mask. It can be used as
|
- Encoder output padding mask. It can be used as
|
||||||
memory_key_padding_mask for the decoder. Its shape is [N, T].
|
memory_key_padding_mask for the decoder. Its shape is (N, T).
|
||||||
It is None if `supervision` is None.
|
It is None if `supervision` is None.
|
||||||
"""
|
"""
|
||||||
if self.use_feat_batchnorm:
|
if self.use_feat_batchnorm:
|
||||||
x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
|
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||||
x = self.feat_batchnorm(x)
|
x = self.feat_batchnorm(x)
|
||||||
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
|
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||||
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
||||||
x, supervision
|
x, supervision
|
||||||
)
|
)
|
||||||
@ -195,7 +196,7 @@ class Transformer(nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
The model input. Its shape is [N, T, C].
|
The model input. Its shape is (N, T, C).
|
||||||
supervisions:
|
supervisions:
|
||||||
Supervision in lhotse format.
|
Supervision in lhotse format.
|
||||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||||
@ -206,8 +207,8 @@ class Transformer(nn.Module):
|
|||||||
padding mask for the decoder.
|
padding mask for the decoder.
|
||||||
Returns:
|
Returns:
|
||||||
Return a tuple with two tensors:
|
Return a tuple with two tensors:
|
||||||
- The encoder output, with shape [T, N, C]
|
- The encoder output, with shape (T, N, C)
|
||||||
- encoder padding mask, with shape [N, T].
|
- encoder padding mask, with shape (N, T).
|
||||||
The mask is None if `supervisions` is None.
|
The mask is None if `supervisions` is None.
|
||||||
It is used as memory key padding mask in the decoder.
|
It is used as memory key padding mask in the decoder.
|
||||||
"""
|
"""
|
||||||
@ -225,17 +226,18 @@ class Transformer(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
The output tensor from the transformer encoder.
|
The output tensor from the transformer encoder.
|
||||||
Its shape is [T, N, C]
|
Its shape is (T, N, C)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor that can be used for CTC decoding.
|
Return a tensor that can be used for CTC decoding.
|
||||||
Its shape is [N, T, C]
|
Its shape is (N, T, C)
|
||||||
"""
|
"""
|
||||||
x = self.encoder_output_layer(x)
|
x = self.encoder_output_layer(x)
|
||||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
|
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def decoder_forward(
|
def decoder_forward(
|
||||||
self,
|
self,
|
||||||
memory: torch.Tensor,
|
memory: torch.Tensor,
|
||||||
@ -247,7 +249,7 @@ class Transformer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
memory:
|
memory:
|
||||||
It's the output of the encoder with shape [T, N, C]
|
It's the output of the encoder with shape (T, N, C)
|
||||||
memory_key_padding_mask:
|
memory_key_padding_mask:
|
||||||
The padding mask from the encoder.
|
The padding mask from the encoder.
|
||||||
token_ids:
|
token_ids:
|
||||||
@ -264,11 +266,15 @@ class Transformer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||||
ys_in = [torch.tensor(y) for y in ys_in]
|
ys_in = [torch.tensor(y) for y in ys_in]
|
||||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
|
ys_in_pad = pad_sequence(
|
||||||
|
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||||
|
)
|
||||||
|
|
||||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||||
ys_out = [torch.tensor(y) for y in ys_out]
|
ys_out = [torch.tensor(y) for y in ys_out]
|
||||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
|
ys_out_pad = pad_sequence(
|
||||||
|
ys_out, batch_first=True, padding_value=float(-1)
|
||||||
|
)
|
||||||
|
|
||||||
device = memory.device
|
device = memory.device
|
||||||
ys_in_pad = ys_in_pad.to(device)
|
ys_in_pad = ys_in_pad.to(device)
|
||||||
@ -301,18 +307,19 @@ class Transformer(nn.Module):
|
|||||||
|
|
||||||
return decoder_loss
|
return decoder_loss
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def decoder_nll(
|
def decoder_nll(
|
||||||
self,
|
self,
|
||||||
memory: torch.Tensor,
|
memory: torch.Tensor,
|
||||||
memory_key_padding_mask: torch.Tensor,
|
memory_key_padding_mask: torch.Tensor,
|
||||||
token_ids: List[List[int]],
|
token_ids: List[torch.Tensor],
|
||||||
sos_id: int,
|
sos_id: int,
|
||||||
eos_id: int,
|
eos_id: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
memory:
|
memory:
|
||||||
It's the output of the encoder with shape [T, N, C]
|
It's the output of the encoder with shape (T, N, C)
|
||||||
memory_key_padding_mask:
|
memory_key_padding_mask:
|
||||||
The padding mask from the encoder.
|
The padding mask from the encoder.
|
||||||
token_ids:
|
token_ids:
|
||||||
@ -328,14 +335,23 @@ class Transformer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# The common part between this function and decoder_forward could be
|
# The common part between this function and decoder_forward could be
|
||||||
# extracted as a separate function.
|
# extracted as a separate function.
|
||||||
|
if isinstance(token_ids[0], torch.Tensor):
|
||||||
|
# This branch is executed by torchscript in C++.
|
||||||
|
# See https://github.com/k2-fsa/k2/pull/870
|
||||||
|
# https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
|
||||||
|
token_ids = [tolist(t) for t in token_ids]
|
||||||
|
|
||||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||||
ys_in = [torch.tensor(y) for y in ys_in]
|
ys_in = [torch.tensor(y) for y in ys_in]
|
||||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
|
ys_in_pad = pad_sequence(
|
||||||
|
ys_in, batch_first=True, padding_value=float(eos_id)
|
||||||
|
)
|
||||||
|
|
||||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||||
ys_out = [torch.tensor(y) for y in ys_out]
|
ys_out = [torch.tensor(y) for y in ys_out]
|
||||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
|
ys_out_pad = pad_sequence(
|
||||||
|
ys_out, batch_first=True, padding_value=float(-1)
|
||||||
|
)
|
||||||
|
|
||||||
device = memory.device
|
device = memory.device
|
||||||
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
||||||
@ -649,24 +665,24 @@ class PositionalEncoding(nn.Module):
|
|||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.xscale = math.sqrt(self.d_model)
|
self.xscale = math.sqrt(self.d_model)
|
||||||
self.dropout = nn.Dropout(p=dropout)
|
self.dropout = nn.Dropout(p=dropout)
|
||||||
self.pe = None
|
# not doing: self.pe = None because of errors thrown by torchscript
|
||||||
|
self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
|
||||||
|
|
||||||
def extend_pe(self, x: torch.Tensor) -> None:
|
def extend_pe(self, x: torch.Tensor) -> None:
|
||||||
"""Extend the time t in the positional encoding if required.
|
"""Extend the time t in the positional encoding if required.
|
||||||
|
|
||||||
The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
|
The shape of `self.pe` is (1, T1, d_model). The shape of the input x
|
||||||
is [N, T, d_model]. If T > T1, then we change the shape of self.pe
|
is (N, T, d_model). If T > T1, then we change the shape of self.pe
|
||||||
to [N, T, d_model]. Otherwise, nothing is done.
|
to (N, T, d_model). Otherwise, nothing is done.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
It is a tensor of shape [N, T, C].
|
It is a tensor of shape (N, T, C).
|
||||||
Returns:
|
Returns:
|
||||||
Return None.
|
Return None.
|
||||||
"""
|
"""
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
if self.pe.size(1) >= x.size(1):
|
if self.pe.size(1) >= x.size(1):
|
||||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
|
||||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||||
return
|
return
|
||||||
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
|
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
|
||||||
@ -678,7 +694,7 @@ class PositionalEncoding(nn.Module):
|
|||||||
pe[:, 0::2] = torch.sin(position * div_term)
|
pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
pe[:, 1::2] = torch.cos(position * div_term)
|
pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
pe = pe.unsqueeze(0)
|
pe = pe.unsqueeze(0)
|
||||||
# Now pe is of shape [1, T, d_model], where T is x.size(1)
|
# Now pe is of shape (1, T, d_model), where T is x.size(1)
|
||||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@ -687,10 +703,10 @@ class PositionalEncoding(nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
Its shape is [N, T, C]
|
Its shape is (N, T, C)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape [N, T, C]
|
Return a tensor of shape (N, T, C)
|
||||||
"""
|
"""
|
||||||
self.extend_pe(x)
|
self.extend_pe(x)
|
||||||
x = x * self.xscale + self.pe[:, : x.size(1), :]
|
x = x * self.xscale + self.pe[:, : x.size(1), :]
|
||||||
@ -784,73 +800,6 @@ class Noam(object):
|
|||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
class LabelSmoothingLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Label-smoothing loss. KL-divergence between
|
|
||||||
q_{smoothed ground truth prob.}(w)
|
|
||||||
and p_{prob. computed by model}(w) is minimized.
|
|
||||||
Modified from
|
|
||||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
|
|
||||||
|
|
||||||
Args:
|
|
||||||
size: the number of class
|
|
||||||
padding_idx: padding_idx: ignored class id
|
|
||||||
smoothing: smoothing rate (0.0 means the conventional CE)
|
|
||||||
normalize_length: normalize loss by sequence length if True
|
|
||||||
criterion: loss function to be smoothed
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
size: int,
|
|
||||||
padding_idx: int = -1,
|
|
||||||
smoothing: float = 0.1,
|
|
||||||
normalize_length: bool = False,
|
|
||||||
criterion: nn.Module = nn.KLDivLoss(reduction="none"),
|
|
||||||
) -> None:
|
|
||||||
"""Construct an LabelSmoothingLoss object."""
|
|
||||||
super(LabelSmoothingLoss, self).__init__()
|
|
||||||
self.criterion = criterion
|
|
||||||
self.padding_idx = padding_idx
|
|
||||||
assert 0.0 < smoothing <= 1.0
|
|
||||||
self.confidence = 1.0 - smoothing
|
|
||||||
self.smoothing = smoothing
|
|
||||||
self.size = size
|
|
||||||
self.true_dist = None
|
|
||||||
self.normalize_length = normalize_length
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute loss between x and target.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x:
|
|
||||||
prediction of dimension
|
|
||||||
(batch_size, input_length, number_of_classes).
|
|
||||||
target:
|
|
||||||
target masked with self.padding_id of
|
|
||||||
dimension (batch_size, input_length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A scalar tensor containing the loss without normalization.
|
|
||||||
"""
|
|
||||||
assert x.size(2) == self.size
|
|
||||||
# batch_size = x.size(0)
|
|
||||||
x = x.view(-1, self.size)
|
|
||||||
target = target.view(-1)
|
|
||||||
with torch.no_grad():
|
|
||||||
true_dist = x.clone()
|
|
||||||
true_dist.fill_(self.smoothing / (self.size - 1))
|
|
||||||
ignore = target == self.padding_idx # (B,)
|
|
||||||
total = len(target) - ignore.sum().item()
|
|
||||||
target = target.masked_fill(ignore, 0) # avoid -1 index
|
|
||||||
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
|
|
||||||
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
|
|
||||||
# denom = total if self.normalize_length else batch_size
|
|
||||||
denom = total if self.normalize_length else 1
|
|
||||||
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
|
|
||||||
|
|
||||||
|
|
||||||
def encoder_padding_mask(
|
def encoder_padding_mask(
|
||||||
max_len: int, supervisions: Optional[Supervisions] = None
|
max_len: int, supervisions: Optional[Supervisions] = None
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
@ -972,10 +921,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
|
|||||||
Return a new list-of-list, where each sublist starts
|
Return a new list-of-list, where each sublist starts
|
||||||
with SOS ID.
|
with SOS ID.
|
||||||
"""
|
"""
|
||||||
ans = []
|
return [[sos_id] + utt for utt in token_ids]
|
||||||
for utt in token_ids:
|
|
||||||
ans.append([sos_id] + utt)
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
|
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
|
||||||
@ -992,7 +938,9 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
|
|||||||
Return a new list-of-list, where each sublist ends
|
Return a new list-of-list, where each sublist ends
|
||||||
with EOS ID.
|
with EOS ID.
|
||||||
"""
|
"""
|
||||||
ans = []
|
return [utt + [eos_id] for utt in token_ids]
|
||||||
for utt in token_ids:
|
|
||||||
ans.append(utt + [eos_id])
|
|
||||||
return ans
|
def tolist(t: torch.Tensor) -> List[int]:
|
||||||
|
"""Used by jit"""
|
||||||
|
return torch.jit.annotate(List[int], t.tolist())
|
||||||
|
@ -28,12 +28,12 @@ from conformer import Conformer
|
|||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.decode import one_best_decoding
|
from icefall.decode import one_best_decoding
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
get_alignments,
|
get_alignments,
|
||||||
get_env_info,
|
|
||||||
save_alignments,
|
save_alignments,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
)
|
)
|
||||||
|
@ -40,14 +40,13 @@ from icefall.decode import (
|
|||||||
rescore_with_n_best_list,
|
rescore_with_n_best_list,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
get_env_info,
|
|
||||||
get_texts,
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
str2bool,
|
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -122,17 +121,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--export",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""When enabled, the averaged model is saved to
|
|
||||||
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
|
|
||||||
pretrained.pt contains a dict {"model": model.state_dict()},
|
|
||||||
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -671,13 +659,6 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
if params.export:
|
|
||||||
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
|
||||||
torch.save(
|
|
||||||
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
@ -36,7 +36,7 @@ from icefall.decode import (
|
|||||||
rescore_with_attention_decoder,
|
rescore_with_attention_decoder,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, get_env_info, get_texts
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -256,7 +256,6 @@ def main():
|
|||||||
params.num_decoder_layers = 0
|
params.num_decoder_layers = 0
|
||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
params["env_info"] = get_env_info()
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -41,12 +41,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
|||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
MetricsTracker,
|
MetricsTracker,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
get_env_info,
|
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
|
@ -36,10 +36,10 @@ from icefall.decode import (
|
|||||||
rescore_with_n_best_list,
|
rescore_with_n_best_list,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
get_env_info,
|
|
||||||
get_texts,
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
|
@ -34,7 +34,7 @@ from icefall.decode import (
|
|||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, get_env_info, get_texts
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -159,7 +159,6 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
params["env_info"] = get_env_info()
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
MetricsTracker,
|
MetricsTracker,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
get_env_info,
|
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
|
@ -34,7 +34,7 @@ from icefall.decode import (
|
|||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, get_env_info, get_texts
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -159,7 +159,6 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
params["env_info"] = get_env_info()
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
MetricsTracker,
|
MetricsTracker,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
get_env_info,
|
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
|
@ -34,7 +34,7 @@ from icefall.decode import (
|
|||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, get_env_info, get_texts
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -159,7 +159,6 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
params["env_info"] = get_env_info()
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
MetricsTracker,
|
MetricsTracker,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
get_env_info,
|
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
|
@ -14,10 +14,10 @@ from model import Tdnn
|
|||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.decode import get_lattice, one_best_decoding
|
from icefall.decode import get_lattice, one_best_decoding
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
get_env_info,
|
|
||||||
get_texts,
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
|
@ -29,7 +29,7 @@ from model import Tdnn
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.decode import get_lattice, one_best_decoding
|
from icefall.decode import get_lattice, one_best_decoding
|
||||||
from icefall.utils import AttributeDict, get_env_info, get_texts
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -116,7 +116,6 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
params["env_info"] = get_env_info()
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -22,15 +22,10 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
from icefall.env import get_env_info
|
||||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
AttributeDict,
|
|
||||||
MetricsTracker,
|
|
||||||
get_env_info,
|
|
||||||
setup_logger,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
106
icefall/env.py
Normal file
106
icefall/env.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
|
# Wei Kang)
|
||||||
|
#
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import k2.version
|
||||||
|
import lhotse
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_sha1():
|
||||||
|
git_commit = (
|
||||||
|
subprocess.run(
|
||||||
|
["git", "rev-parse", "--short", "HEAD"],
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
.stdout.decode()
|
||||||
|
.rstrip("\n")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
dirty_commit = (
|
||||||
|
len(
|
||||||
|
subprocess.run(
|
||||||
|
["git", "diff", "--shortstat"],
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
.stdout.decode()
|
||||||
|
.rstrip("\n")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
> 0
|
||||||
|
)
|
||||||
|
git_commit = (
|
||||||
|
git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
|
||||||
|
)
|
||||||
|
return git_commit
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_date():
|
||||||
|
git_date = (
|
||||||
|
subprocess.run(
|
||||||
|
["git", "log", "-1", "--format=%ad", "--date=local"],
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
.stdout.decode()
|
||||||
|
.rstrip("\n")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
return git_date
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_branch_name():
|
||||||
|
git_date = (
|
||||||
|
subprocess.run(
|
||||||
|
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
.stdout.decode()
|
||||||
|
.rstrip("\n")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
return git_date
|
||||||
|
|
||||||
|
|
||||||
|
def get_env_info() -> Dict[str, Any]:
|
||||||
|
"""Get the environment information."""
|
||||||
|
return {
|
||||||
|
"k2-version": k2.version.__version__,
|
||||||
|
"k2-build-type": k2.version.__build_type__,
|
||||||
|
"k2-with-cuda": k2.with_cuda,
|
||||||
|
"k2-git-sha1": k2.version.__git_sha1__,
|
||||||
|
"k2-git-date": k2.version.__git_date__,
|
||||||
|
"lhotse-version": lhotse.__version__,
|
||||||
|
"torch-cuda-available": torch.cuda.is_available(),
|
||||||
|
"torch-cuda-version": torch.version.cuda,
|
||||||
|
"python-version": sys.version[:3],
|
||||||
|
"icefall-git-branch": get_git_branch_name(),
|
||||||
|
"icefall-git-sha1": get_git_sha1(),
|
||||||
|
"icefall-git-date": get_git_date(),
|
||||||
|
"icefall-path": str(Path(__file__).resolve().parent.parent),
|
||||||
|
"k2-path": str(Path(k2.__file__).resolve()),
|
||||||
|
"lhotse-path": str(Path(lhotse.__file__).resolve()),
|
||||||
|
}
|
@ -21,17 +21,15 @@ import collections
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
|
from typing import Dict, Iterable, List, TextIO, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import k2.version
|
import k2.version
|
||||||
import kaldialign
|
import kaldialign
|
||||||
import lhotse
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
@ -137,85 +135,6 @@ def setup_logger(
|
|||||||
logging.getLogger("").addHandler(console)
|
logging.getLogger("").addHandler(console)
|
||||||
|
|
||||||
|
|
||||||
def get_git_sha1():
|
|
||||||
git_commit = (
|
|
||||||
subprocess.run(
|
|
||||||
["git", "rev-parse", "--short", "HEAD"],
|
|
||||||
check=True,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
)
|
|
||||||
.stdout.decode()
|
|
||||||
.rstrip("\n")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
dirty_commit = (
|
|
||||||
len(
|
|
||||||
subprocess.run(
|
|
||||||
["git", "diff", "--shortstat"],
|
|
||||||
check=True,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
)
|
|
||||||
.stdout.decode()
|
|
||||||
.rstrip("\n")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
> 0
|
|
||||||
)
|
|
||||||
git_commit = (
|
|
||||||
git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
|
|
||||||
)
|
|
||||||
return git_commit
|
|
||||||
|
|
||||||
|
|
||||||
def get_git_date():
|
|
||||||
git_date = (
|
|
||||||
subprocess.run(
|
|
||||||
["git", "log", "-1", "--format=%ad", "--date=local"],
|
|
||||||
check=True,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
)
|
|
||||||
.stdout.decode()
|
|
||||||
.rstrip("\n")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
return git_date
|
|
||||||
|
|
||||||
|
|
||||||
def get_git_branch_name():
|
|
||||||
git_date = (
|
|
||||||
subprocess.run(
|
|
||||||
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
|
||||||
check=True,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
)
|
|
||||||
.stdout.decode()
|
|
||||||
.rstrip("\n")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
return git_date
|
|
||||||
|
|
||||||
|
|
||||||
def get_env_info() -> Dict[str, Any]:
|
|
||||||
"""Get the environment information."""
|
|
||||||
return {
|
|
||||||
"k2-version": k2.version.__version__,
|
|
||||||
"k2-build-type": k2.version.__build_type__,
|
|
||||||
"k2-with-cuda": k2.with_cuda,
|
|
||||||
"k2-git-sha1": k2.version.__git_sha1__,
|
|
||||||
"k2-git-date": k2.version.__git_date__,
|
|
||||||
"lhotse-version": lhotse.__version__,
|
|
||||||
"torch-cuda-available": torch.cuda.is_available(),
|
|
||||||
"torch-cuda-version": torch.version.cuda,
|
|
||||||
"python-version": sys.version[:3],
|
|
||||||
"icefall-git-branch": get_git_branch_name(),
|
|
||||||
"icefall-git-sha1": get_git_sha1(),
|
|
||||||
"icefall-git-date": get_git_date(),
|
|
||||||
"icefall-path": str(Path(__file__).resolve().parent.parent),
|
|
||||||
"k2-path": str(Path(k2.__file__).resolve()),
|
|
||||||
"lhotse-path": str(Path(lhotse.__file__).resolve()),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AttributeDict(dict):
|
class AttributeDict(dict):
|
||||||
def __getattr__(self, key):
|
def __getattr__(self, key):
|
||||||
if key in self:
|
if key in self:
|
||||||
|
@ -20,12 +20,8 @@ import k2
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from icefall.utils import (
|
from icefall.env import get_env_info
|
||||||
AttributeDict,
|
from icefall.utils import AttributeDict, encode_supervisions, get_texts
|
||||||
encode_supervisions,
|
|
||||||
get_env_info,
|
|
||||||
get_texts,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
Loading…
x
Reference in New Issue
Block a user