mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
CTC attention model with reworked Conformer encoder and reworked Transformer decoder (#462)
* ctc attention model with reworked conformer encoder and reworked transformer decoder * remove unnecessary func * resolve flake8 conflicts * fix typos and modify the expr of ScaledEmbedding * use original beam size * minor changes to the scripts * add rnn lm decoding * minor changes * check whether q k v weight is None * check whether q k v weight is None * check whether q k v weight is None * style correction * update results * update results * upload the decoding results of rnn-lm to the RESULTS * upload the decoding results of rnn-lm to the RESULTS * Update egs/librispeech/ASR/RESULTS.md Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> * Update egs/librispeech/ASR/RESULTS.md Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> * Update egs/librispeech/ASR/RESULTS.md Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
This commit is contained in:
parent
3d2986b4c2
commit
116d0cf26d
6
.flake8
6
.flake8
@ -4,12 +4,14 @@ statistics=true
|
|||||||
max-line-length = 80
|
max-line-length = 80
|
||||||
per-file-ignores =
|
per-file-ignores =
|
||||||
# line too long
|
# line too long
|
||||||
icefall/diagnostics.py: E501
|
icefall/diagnostics.py: E501,
|
||||||
egs/*/ASR/*/conformer.py: E501,
|
egs/*/ASR/*/conformer.py: E501,
|
||||||
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
|
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
|
||||||
egs/*/ASR/*/optim.py: E501,
|
egs/*/ASR/*/optim.py: E501,
|
||||||
egs/*/ASR/*/scaling.py: E501,
|
egs/*/ASR/*/scaling.py: E501,
|
||||||
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
|
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203,
|
||||||
|
egs/librispeech/ASR/conformer_ctc2/*py: E501,
|
||||||
|
egs/librispeech/ASR/RESULTS.md: E999,
|
||||||
|
|
||||||
# invalid escape sequence (cause by tex formular), W605
|
# invalid escape sequence (cause by tex formular), W605
|
||||||
icefall/utils.py: E501, W605
|
icefall/utils.py: E501, W605
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
## Results
|
## Results
|
||||||
|
|
||||||
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T 2)
|
#### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T 2)
|
||||||
|
|
||||||
[conv_emformer_transducer_stateless2](./conv_emformer_transducer_stateless2)
|
[conv_emformer_transducer_stateless2](./conv_emformer_transducer_stateless2)
|
||||||
|
|
||||||
@ -1998,6 +1998,118 @@ avg=11
|
|||||||
You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3xqTpyVmWi5FnWjrA>
|
You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3xqTpyVmWi5FnWjrA>
|
||||||
|
|
||||||
|
|
||||||
|
### LibriSpeech BPE training results (Conformer-CTC 2)
|
||||||
|
|
||||||
|
#### [conformer_ctc2](./conformer_ctc2)
|
||||||
|
|
||||||
|
#### 2022-07-21
|
||||||
|
|
||||||
|
It implements a 'reworked' version of CTC attention model.
|
||||||
|
As demenstrated by pruned_transducer_stateless2, reworked Conformer model has superior performance compared to the original Conformer.
|
||||||
|
So in this modified version of CTC attention model, it has the reworked Conformer as the encoder and the reworked Transformer as the decoder.
|
||||||
|
conformer_ctc2 also integrates with the idea of the 'averaging models' in pruned_transducer_stateless4.
|
||||||
|
|
||||||
|
The WERs on comparisons with a baseline model, for the librispeech test dataset, are listed as below.
|
||||||
|
|
||||||
|
The baseline model is the original conformer CTC attention model trained with icefall/egs/librispeech/ASR/conformer_ctc.
|
||||||
|
The model is downloaded from <https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09>.
|
||||||
|
This model has 12 layers of Conformer encoder layers and 6 Transformer decoder layers.
|
||||||
|
Number of model parameters is 109,226,120.
|
||||||
|
It has been trained with 90 epochs with full Librispeech dataset.
|
||||||
|
|
||||||
|
For this reworked CTC attention model, it has 12 layers of reworked Conformer encoder layers and 6 reworked Transformer decoder layers.
|
||||||
|
Number of model parameters is 103,071,035.
|
||||||
|
With full Librispeech data set, it was trained for **only** 30 epochs because the reworked model would converge much faster.
|
||||||
|
Please refer to <https://tensorboard.dev/experiment/GR1s6VrJRTW5rtB50jakew/#scalars> to see the loss convergence curve.
|
||||||
|
Please find the above trained model at <https://huggingface.co/WayneWiser/icefall-asr-librispeech-conformer-ctc2-jit-bpe-500-2022-07-21> in huggingface.
|
||||||
|
|
||||||
|
The decoding configuration for the reworked model is --epoch 30, --avg 8, --use-averaged-model True, which is the best after searching.
|
||||||
|
|
||||||
|
| WER | reworked ctc attention | with --epoch 30 --avg 8 --use-averaged-model True | | ctc attention| with --epoch 77 --avg 55 | |
|
||||||
|
|------------------------|-------|------|------|------|------|-----|
|
||||||
|
| test sets | test-clean | test-other | Avg | test-clean | test-other | Avg |
|
||||||
|
| ctc-greedy-search | 2.98% | 7.14%| 5.06%| 2.90%| 7.47%| 5.19%|
|
||||||
|
| ctc-decoding | 2.98% | 7.14%| 5.06%| 2.90%| 7.47%| 5.19%|
|
||||||
|
| 1best | 2.93% | 6.37%| 4.65%| 2.70%| 6.49%| 4.60%|
|
||||||
|
| nbest | 2.94% | 6.39%| 4.67%| 2.70%| 6.48%| 4.59%|
|
||||||
|
| nbest-rescoring | 2.68% | 5.77%| 4.23%| 2.55%| 6.07%| 4.31%|
|
||||||
|
| whole-lattice-rescoring| 2.66% | 5.76%| 4.21%| 2.56%| 6.04%| 4.30%|
|
||||||
|
| attention-decoder | 2.59% | 5.54%| 4.07%| 2.41%| 5.77%| 4.09%|
|
||||||
|
| nbest-oracle | 1.53% | 3.47%| 2.50%| 1.69%| 4.02%| 2.86%|
|
||||||
|
| rnn-lm | 2.37% | 4.98%| 3.68%| 2.31%| 5.35%| 3.83%|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
conformer_ctc2 also implements the CTC greedy search decoding, it has the identical WERs with the CTC-decoding method.
|
||||||
|
For other decoding methods, the average WER of the two test sets with the two models is similar.
|
||||||
|
Except for the 1best and nbest methods, the overall performance of reworked model is better than the baseline model.
|
||||||
|
|
||||||
|
|
||||||
|
To reproduce the above result, use the following commands.
|
||||||
|
|
||||||
|
The training commands are:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
WORLD_SIZE=8
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||||
|
./conformer_ctc2/train.py \
|
||||||
|
--manifest-dir data/fbank \
|
||||||
|
--exp-dir conformer_ctc2/exp \
|
||||||
|
--full-libri 1 \
|
||||||
|
--spec-aug-time-warp-factor 80 \
|
||||||
|
--max-duration 300 \
|
||||||
|
--world-size ${WORLD_SIZE} \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--att-rate 0.7 \
|
||||||
|
--num-decoder-layers 6
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
And the following commands are for decoding:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
|
||||||
|
|
||||||
|
for method in ctc-greedy-search ctc-decoding 1best nbest-oracle; do
|
||||||
|
python3 ./conformer_ctc2/decode.py \
|
||||||
|
--exp-dir conformer_ctc2/exp \
|
||||||
|
--use-averaged-model True --epoch 30 --avg 8 --max-duration 200 --method $method
|
||||||
|
done
|
||||||
|
|
||||||
|
for method in nbest nbest-rescoring whole-lattice-rescoring attention-decoder ; do
|
||||||
|
python3 ./conformer_ctc2/decode.py \
|
||||||
|
--exp-dir conformer_ctc2/exp \
|
||||||
|
--use-averaged-model True --epoch 30 --avg 8 --max-duration 20 --method $method
|
||||||
|
done
|
||||||
|
|
||||||
|
rnn_dir=$(git rev-parse --show-toplevel)/icefall/rnn_lm
|
||||||
|
./conformer_ctc2/decode.py \
|
||||||
|
--exp-dir conformer_ctc2/exp \
|
||||||
|
--lang-dir data/lang_bpe_500 \
|
||||||
|
--lm-dir data/lm \
|
||||||
|
--max-duration 30 \
|
||||||
|
--concatenate-cuts 0 \
|
||||||
|
--bucketing-sampler 1 \
|
||||||
|
--num-paths 1000 \
|
||||||
|
--use-averaged-model True \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 8 \
|
||||||
|
--nbest-scale 0.5 \
|
||||||
|
--rnn-lm-exp-dir ${rnn_dir}/exp \
|
||||||
|
--rnn-lm-epoch 29 \
|
||||||
|
--rnn-lm-avg 3 \
|
||||||
|
--rnn-lm-embedding-dim 2048 \
|
||||||
|
--rnn-lm-hidden-dim 2048 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--rnn-lm-tie-weights true \
|
||||||
|
--method rnn-lm
|
||||||
|
```
|
||||||
|
|
||||||
|
You can find the RNN-LM pre-trained model at
|
||||||
|
<https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
|
||||||
|
|
||||||
|
|
||||||
### LibriSpeech BPE training results (Conformer-CTC)
|
### LibriSpeech BPE training results (Conformer-CTC)
|
||||||
|
|
||||||
#### 2021-11-09
|
#### 2021-11-09
|
||||||
|
1
egs/librispeech/ASR/conformer_ctc2/__init__.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/__init__.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/__init__.py
|
1
egs/librispeech/ASR/conformer_ctc2/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/asr_datamodule.py
|
252
egs/librispeech/ASR/conformer_ctc2/attention.py
Normal file
252
egs/librispeech/ASR/conformer_ctc2/attention.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
# Copyright 2022 Xiaomi Corp. (author: Quandong Wang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn.init import xavier_normal_
|
||||||
|
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Module):
|
||||||
|
r"""Allows the model to jointly attend to information
|
||||||
|
from different representation subspaces.
|
||||||
|
See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||||
|
|
||||||
|
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embed_dim: Total dimension of the model.
|
||||||
|
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
||||||
|
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
||||||
|
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
||||||
|
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
||||||
|
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
||||||
|
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
||||||
|
Default: ``False``.
|
||||||
|
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
||||||
|
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
||||||
|
batch_first: If ``True``, then the input and output tensors are provided
|
||||||
|
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||||
|
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||||
|
"""
|
||||||
|
__constants__ = ["batch_first"]
|
||||||
|
bias_k: Optional[torch.Tensor]
|
||||||
|
bias_v: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=True,
|
||||||
|
add_bias_kv=False,
|
||||||
|
add_zero_attn=False,
|
||||||
|
kdim=None,
|
||||||
|
vdim=None,
|
||||||
|
batch_first=False,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super(MultiheadAttention, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.kdim = kdim if kdim is not None else embed_dim
|
||||||
|
self.vdim = vdim if vdim is not None else embed_dim
|
||||||
|
self._qkv_same_embed_dim = (
|
||||||
|
self.kdim == embed_dim and self.vdim == embed_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.batch_first = batch_first
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
|
||||||
|
if self._qkv_same_embed_dim is False:
|
||||||
|
self.q_proj_weight = ScaledLinear(embed_dim, embed_dim, bias=bias)
|
||||||
|
self.k_proj_weight = ScaledLinear(self.kdim, embed_dim, bias=bias)
|
||||||
|
self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias)
|
||||||
|
self.register_parameter("in_proj_weight", None)
|
||||||
|
else:
|
||||||
|
self.in_proj_weight = ScaledLinear(
|
||||||
|
embed_dim, 3 * embed_dim, bias=bias
|
||||||
|
)
|
||||||
|
self.register_parameter("q_proj_weight", None)
|
||||||
|
self.register_parameter("k_proj_weight", None)
|
||||||
|
self.register_parameter("v_proj_weight", None)
|
||||||
|
|
||||||
|
if not bias:
|
||||||
|
self.register_parameter("in_proj_bias", None)
|
||||||
|
|
||||||
|
self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias)
|
||||||
|
|
||||||
|
if add_bias_kv:
|
||||||
|
self.bias_k = nn.Parameter(
|
||||||
|
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.bias_v = nn.Parameter(
|
||||||
|
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.bias_k = self.bias_v = None
|
||||||
|
|
||||||
|
self.add_zero_attn = add_zero_attn
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
if self.bias_k is not None:
|
||||||
|
xavier_normal_(self.bias_k)
|
||||||
|
if self.bias_v is not None:
|
||||||
|
xavier_normal_(self.bias_v)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
||||||
|
if "_qkv_same_embed_dim" not in state:
|
||||||
|
state["_qkv_same_embed_dim"] = True
|
||||||
|
|
||||||
|
super(MultiheadAttention, self).__setstate__(state)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
query: Query embeddings of shape :math:`(L, N, E_q)` when ``batch_first=False`` or :math:`(N, L, E_q)`
|
||||||
|
when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is the batch size,
|
||||||
|
and :math:`E_q` is the query embedding dimension ``embed_dim``. Queries are compared against
|
||||||
|
key-value pairs to produce the output. See "Attention Is All You Need" for more details.
|
||||||
|
key: Key embeddings of shape :math:`(S, N, E_k)` when ``batch_first=False`` or :math:`(N, S, E_k)` when
|
||||||
|
``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size, and
|
||||||
|
:math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for more details.
|
||||||
|
value: Value embeddings of shape :math:`(S, N, E_v)` when ``batch_first=False`` or :math:`(N, S, E_v)` when
|
||||||
|
``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size, and
|
||||||
|
:math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for more details.
|
||||||
|
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
||||||
|
to ignore for the purpose of attention (i.e. treat as "padding"). Binary and byte masks are supported.
|
||||||
|
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
||||||
|
the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
|
||||||
|
value will be ignored.
|
||||||
|
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
||||||
|
Default: ``True``.
|
||||||
|
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
||||||
|
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
||||||
|
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
||||||
|
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
||||||
|
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
||||||
|
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
|
||||||
|
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
||||||
|
the attention weight.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- **attn_output** - Attention outputs of shape :math:`(L, N, E)` when ``batch_first=False`` or
|
||||||
|
:math:`(N, L, E)` when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is
|
||||||
|
the batch size, and :math:`E` is the embedding dimension ``embed_dim``.
|
||||||
|
- **attn_output_weights** - Attention output weights of shape :math:`(N, L, S)`, where :math:`N` is the batch
|
||||||
|
size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length. Only returned
|
||||||
|
when ``need_weights=True``.
|
||||||
|
"""
|
||||||
|
if self.batch_first:
|
||||||
|
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
|
||||||
|
|
||||||
|
if not self._qkv_same_embed_dim:
|
||||||
|
q_proj_weight = (
|
||||||
|
self.q_proj_weight.get_weight()
|
||||||
|
if self.q_proj_weight is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
k_proj_weight = (
|
||||||
|
self.k_proj_weight.get_weight()
|
||||||
|
if self.k_proj_weight is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
v_proj_weight = (
|
||||||
|
self.v_proj_weight.get_weight()
|
||||||
|
if self.v_proj_weight is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
(
|
||||||
|
attn_output,
|
||||||
|
attn_output_weights,
|
||||||
|
) = nn.functional.multi_head_attention_forward(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
self.embed_dim,
|
||||||
|
self.num_heads,
|
||||||
|
self.in_proj_weight.get_weight(),
|
||||||
|
self.in_proj_weight.get_bias(),
|
||||||
|
self.bias_k,
|
||||||
|
self.bias_v,
|
||||||
|
self.add_zero_attn,
|
||||||
|
self.dropout,
|
||||||
|
self.out_proj.get_weight(),
|
||||||
|
self.out_proj.get_bias(),
|
||||||
|
training=self.training,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
use_separate_proj_weight=True,
|
||||||
|
q_proj_weight=q_proj_weight,
|
||||||
|
k_proj_weight=k_proj_weight,
|
||||||
|
v_proj_weight=v_proj_weight,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
attn_output,
|
||||||
|
attn_output_weights,
|
||||||
|
) = nn.functional.multi_head_attention_forward(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
self.embed_dim,
|
||||||
|
self.num_heads,
|
||||||
|
self.in_proj_weight.get_weight(),
|
||||||
|
self.in_proj_weight.get_bias(),
|
||||||
|
self.bias_k,
|
||||||
|
self.bias_v,
|
||||||
|
self.add_zero_attn,
|
||||||
|
self.dropout,
|
||||||
|
self.out_proj.get_weight(),
|
||||||
|
self.out_proj.get_bias(),
|
||||||
|
training=self.training,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
)
|
||||||
|
if self.batch_first:
|
||||||
|
return attn_output.transpose(1, 0), attn_output_weights
|
||||||
|
else:
|
||||||
|
return attn_output, attn_output_weights
|
964
egs/librispeech/ASR/conformer_ctc2/conformer.py
Normal file
964
egs/librispeech/ASR/conformer_ctc2/conformer.py
Normal file
@ -0,0 +1,964 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||||
|
# 2022 Xiaomi Corp. (author: Quandong Wang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling import (
|
||||||
|
ActivationBalancer,
|
||||||
|
BasicNorm,
|
||||||
|
DoubleSwish,
|
||||||
|
ScaledConv1d,
|
||||||
|
ScaledLinear,
|
||||||
|
)
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from subsampling import Conv2dSubsampling
|
||||||
|
|
||||||
|
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||||
|
|
||||||
|
|
||||||
|
class Conformer(Transformer):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
num_features (int): Number of input features
|
||||||
|
num_classes (int): Number of output classes
|
||||||
|
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
||||||
|
d_model (int): attention dimension, also the output dimension
|
||||||
|
nhead (int): number of head
|
||||||
|
dim_feedforward (int): feedforward dimention
|
||||||
|
num_encoder_layers (int): number of encoder layers
|
||||||
|
num_decoder_layers (int): number of decoder layers
|
||||||
|
dropout (float): dropout rate
|
||||||
|
layer_dropout (float): layer-dropout rate.
|
||||||
|
cnn_module_kernel (int): Kernel size of convolution module
|
||||||
|
vgg_frontend (bool): whether to use vgg frontend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_features: int,
|
||||||
|
num_classes: int,
|
||||||
|
subsampling_factor: int = 4,
|
||||||
|
d_model: int = 256,
|
||||||
|
nhead: int = 4,
|
||||||
|
dim_feedforward: int = 2048,
|
||||||
|
num_encoder_layers: int = 12,
|
||||||
|
num_decoder_layers: int = 6,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
layer_dropout: float = 0.075,
|
||||||
|
cnn_module_kernel: int = 31,
|
||||||
|
) -> None:
|
||||||
|
super(Conformer, self).__init__(
|
||||||
|
num_features=num_features,
|
||||||
|
num_classes=num_classes,
|
||||||
|
subsampling_factor=subsampling_factor,
|
||||||
|
d_model=d_model,
|
||||||
|
nhead=nhead,
|
||||||
|
dim_feedforward=dim_feedforward,
|
||||||
|
num_encoder_layers=num_encoder_layers,
|
||||||
|
num_decoder_layers=num_decoder_layers,
|
||||||
|
dropout=dropout,
|
||||||
|
layer_dropout=layer_dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_features = num_features
|
||||||
|
self.subsampling_factor = subsampling_factor
|
||||||
|
if subsampling_factor != 4:
|
||||||
|
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||||
|
|
||||||
|
# self.encoder_embed converts the input of shape (N, T, num_features)
|
||||||
|
# to the shape (N, T//subsampling_factor, d_model).
|
||||||
|
# That is, it does two things simultaneously:
|
||||||
|
# (1) subsampling: T -> T//subsampling_factor
|
||||||
|
# (2) embedding: num_features -> d_model
|
||||||
|
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||||
|
|
||||||
|
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||||
|
|
||||||
|
encoder_layer = ConformerEncoderLayer(
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward,
|
||||||
|
dropout,
|
||||||
|
layer_dropout,
|
||||||
|
cnn_module_kernel,
|
||||||
|
)
|
||||||
|
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||||
|
|
||||||
|
def run_encoder(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
supervisions: Optional[Supervisions] = None,
|
||||||
|
warmup: float = 1.0,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||||
|
supervisions:
|
||||||
|
Supervision in lhotse format.
|
||||||
|
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||||
|
CAUTION: It contains length information, i.e., start and number of
|
||||||
|
frames, before subsampling
|
||||||
|
It is read directly from the batch, without any sorting. It is used
|
||||||
|
to compute encoder padding mask, which is used as memory key padding
|
||||||
|
mask for the decoder.
|
||||||
|
warmup:
|
||||||
|
A floating point value that gradually increases from 0 throughout
|
||||||
|
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||||
|
to turn modules on sequentially.
|
||||||
|
Returns:
|
||||||
|
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
||||||
|
Tensor: Mask tensor of dimension (batch_size, input_length)
|
||||||
|
"""
|
||||||
|
x = self.encoder_embed(x)
|
||||||
|
x, pos_emb = self.encoder_pos(x)
|
||||||
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
mask = encoder_padding_mask(x.size(0), supervisions)
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.to(x.device)
|
||||||
|
|
||||||
|
# Caution: We assume the subsampling factor is 4!
|
||||||
|
|
||||||
|
x = self.encoder(
|
||||||
|
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||||
|
) # (T, N, C)
|
||||||
|
|
||||||
|
# x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
|
# return x, lengths
|
||||||
|
return x, mask
|
||||||
|
|
||||||
|
|
||||||
|
class ConformerEncoderLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
|
||||||
|
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d_model: the number of expected features in the input (required).
|
||||||
|
nhead: the number of heads in the multiheadattention models (required).
|
||||||
|
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||||
|
dropout: the dropout value (default=0.1).
|
||||||
|
cnn_module_kernel (int): Kernel size of convolution module.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||||
|
>>> src = torch.rand(10, 32, 512)
|
||||||
|
>>> pos_emb = torch.rand(32, 19, 512)
|
||||||
|
>>> out = encoder_layer(src, pos_emb)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
nhead: int,
|
||||||
|
dim_feedforward: int = 2048,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
layer_dropout: float = 0.075,
|
||||||
|
cnn_module_kernel: int = 31,
|
||||||
|
) -> None:
|
||||||
|
super(ConformerEncoderLayer, self).__init__()
|
||||||
|
|
||||||
|
self.layer_dropout = layer_dropout
|
||||||
|
|
||||||
|
self.d_model = d_model
|
||||||
|
|
||||||
|
self.self_attn = RelPositionMultiheadAttention(
|
||||||
|
d_model, nhead, dropout=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
self.feed_forward = nn.Sequential(
|
||||||
|
ScaledLinear(d_model, dim_feedforward),
|
||||||
|
ActivationBalancer(channel_dim=-1),
|
||||||
|
DoubleSwish(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.feed_forward_macaron = nn.Sequential(
|
||||||
|
ScaledLinear(d_model, dim_feedforward),
|
||||||
|
ActivationBalancer(channel_dim=-1),
|
||||||
|
DoubleSwish(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||||
|
|
||||||
|
self.norm_final = BasicNorm(d_model)
|
||||||
|
|
||||||
|
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
||||||
|
self.balancer = ActivationBalancer(
|
||||||
|
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
warmup: float = 1.0,
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Pass the input through the encoder layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder layer (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
|
src_mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||||
|
bypass layers more frequently.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*S-1, E)
|
||||||
|
src_mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, N is the batch size, E is the feature number
|
||||||
|
"""
|
||||||
|
src_orig = src
|
||||||
|
|
||||||
|
warmup_scale = min(0.1 + warmup, 1.0)
|
||||||
|
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
||||||
|
# completely bypass it.
|
||||||
|
if self.training:
|
||||||
|
alpha = (
|
||||||
|
warmup_scale
|
||||||
|
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
|
||||||
|
else 0.1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
|
# macaron style feed forward module
|
||||||
|
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||||
|
|
||||||
|
# multi-headed self-attention module
|
||||||
|
src_att = self.self_attn(
|
||||||
|
src,
|
||||||
|
src,
|
||||||
|
src,
|
||||||
|
pos_emb=pos_emb,
|
||||||
|
attn_mask=src_mask,
|
||||||
|
key_padding_mask=src_key_padding_mask,
|
||||||
|
)[0]
|
||||||
|
src = src + self.dropout(src_att)
|
||||||
|
|
||||||
|
# convolution module
|
||||||
|
src = src + self.dropout(self.conv_module(src))
|
||||||
|
|
||||||
|
# feed forward module
|
||||||
|
src = src + self.dropout(self.feed_forward(src))
|
||||||
|
|
||||||
|
src = self.norm_final(self.balancer(src))
|
||||||
|
|
||||||
|
if alpha != 1.0:
|
||||||
|
src = alpha * src + (1 - alpha) * src_orig
|
||||||
|
|
||||||
|
return src
|
||||||
|
|
||||||
|
|
||||||
|
class ConformerEncoder(nn.Module):
|
||||||
|
r"""ConformerEncoder is a stack of N encoder layers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
|
||||||
|
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||||
|
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
|
||||||
|
>>> src = torch.rand(10, 32, 512)
|
||||||
|
>>> pos_emb = torch.rand(32, 19, 512)
|
||||||
|
>>> out = conformer_encoder(src, pos_emb)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||||
|
)
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
warmup: float = 1.0,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
|
mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*S-1, E)
|
||||||
|
mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
|
|
||||||
|
"""
|
||||||
|
output = src
|
||||||
|
|
||||||
|
for i, mod in enumerate(self.layers):
|
||||||
|
output = mod(
|
||||||
|
output,
|
||||||
|
pos_emb,
|
||||||
|
src_mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
warmup=warmup,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class RelPositionalEncoding(torch.nn.Module):
|
||||||
|
"""Relative positional encoding module.
|
||||||
|
|
||||||
|
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||||
|
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d_model: Embedding dimension.
|
||||||
|
dropout_rate: Dropout rate.
|
||||||
|
max_len: Maximum input length.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
||||||
|
) -> None:
|
||||||
|
"""Construct an PositionalEncoding object."""
|
||||||
|
super(RelPositionalEncoding, self).__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||||
|
self.pe = None
|
||||||
|
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||||
|
|
||||||
|
def extend_pe(self, x: Tensor) -> None:
|
||||||
|
"""Reset the positional encodings."""
|
||||||
|
if self.pe is not None:
|
||||||
|
# self.pe contains both positive and negative parts
|
||||||
|
# the length of self.pe is 2 * input_len - 1
|
||||||
|
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||||
|
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||||
|
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||||
|
x.device
|
||||||
|
):
|
||||||
|
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||||
|
return
|
||||||
|
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||||
|
# position of key vector. We use position relative positions when keys
|
||||||
|
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||||
|
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||||
|
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||||
|
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||||
|
div_term = torch.exp(
|
||||||
|
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||||
|
* -(math.log(10000.0) / self.d_model)
|
||||||
|
)
|
||||||
|
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||||
|
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||||
|
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||||
|
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||||
|
|
||||||
|
# Reserve the order of positive indices and concat both positive and
|
||||||
|
# negative indices. This is used to support the shifting trick
|
||||||
|
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||||
|
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||||
|
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||||
|
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||||
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
|
"""Add positional encoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||||
|
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.extend_pe(x)
|
||||||
|
pos_emb = self.pe[
|
||||||
|
:,
|
||||||
|
self.pe.size(1) // 2
|
||||||
|
- x.size(1)
|
||||||
|
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||||
|
+ x.size(1),
|
||||||
|
]
|
||||||
|
return self.dropout(x), self.dropout(pos_emb)
|
||||||
|
|
||||||
|
|
||||||
|
class RelPositionMultiheadAttention(nn.Module):
|
||||||
|
r"""Multi-Head Attention layer with relative position encoding
|
||||||
|
|
||||||
|
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embed_dim: total dimension of the model.
|
||||||
|
num_heads: parallel attention heads.
|
||||||
|
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
|
||||||
|
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
super(RelPositionMultiheadAttention, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
|
||||||
|
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
|
||||||
|
self.out_proj = ScaledLinear(
|
||||||
|
embed_dim, embed_dim, bias=True, initial_scale=0.25
|
||||||
|
)
|
||||||
|
|
||||||
|
# linear transformation for positional encoding.
|
||||||
|
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
|
||||||
|
# these two learnable bias are used in matrix c and matrix d
|
||||||
|
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||||
|
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||||
|
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||||
|
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
|
||||||
|
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
def _pos_bias_u(self):
|
||||||
|
return self.pos_bias_u * self.pos_bias_u_scale.exp()
|
||||||
|
|
||||||
|
def _pos_bias_v(self):
|
||||||
|
return self.pos_bias_v * self.pos_bias_v_scale.exp()
|
||||||
|
|
||||||
|
def _reset_parameters(self) -> None:
|
||||||
|
nn.init.normal_(self.pos_bias_u, std=0.01)
|
||||||
|
nn.init.normal_(self.pos_bias_v, std=0.01)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
query, key, value: map a query and a set of key-value pairs to an output.
|
||||||
|
pos_emb: Positional embedding tensor
|
||||||
|
key_padding_mask: if provided, specified padding elements in the key will
|
||||||
|
be ignored by the attention. When given a binary mask and a value is True,
|
||||||
|
the corresponding value on the attention layer will be ignored. When given
|
||||||
|
a byte mask and a value is non-zero, the corresponding value on the attention
|
||||||
|
layer will be ignored
|
||||||
|
need_weights: output attn_output_weights.
|
||||||
|
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||||
|
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Inputs:
|
||||||
|
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||||
|
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
||||||
|
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
||||||
|
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||||
|
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||||
|
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||||
|
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
||||||
|
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||||
|
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||||
|
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||||
|
is provided, it will be added to the attention weight.
|
||||||
|
|
||||||
|
- Outputs:
|
||||||
|
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||||
|
E is the embedding dimension.
|
||||||
|
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||||
|
L is the target sequence length, S is the source sequence length.
|
||||||
|
"""
|
||||||
|
return self.multi_head_attention_forward(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
pos_emb,
|
||||||
|
self.embed_dim,
|
||||||
|
self.num_heads,
|
||||||
|
self.in_proj.get_weight(),
|
||||||
|
self.in_proj.get_bias(),
|
||||||
|
self.dropout,
|
||||||
|
self.out_proj.get_weight(),
|
||||||
|
self.out_proj.get_bias(),
|
||||||
|
training=self.training,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
def rel_shift(self, x: Tensor) -> Tensor:
|
||||||
|
"""Compute relative positional encoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||||
|
time1 means the length of query vector.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: tensor of shape (batch, head, time1, time2)
|
||||||
|
(note: time2 has the same value as time1, but it is for
|
||||||
|
the key, while time1 is for the query).
|
||||||
|
"""
|
||||||
|
(batch_size, num_heads, time1, n) = x.shape
|
||||||
|
assert n == 2 * time1 - 1
|
||||||
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
|
batch_stride = x.stride(0)
|
||||||
|
head_stride = x.stride(1)
|
||||||
|
time1_stride = x.stride(2)
|
||||||
|
n_stride = x.stride(3)
|
||||||
|
return x.as_strided(
|
||||||
|
(batch_size, num_heads, time1, time1),
|
||||||
|
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||||
|
storage_offset=n_stride * (time1 - 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def multi_head_attention_forward(
|
||||||
|
self,
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
embed_dim_to_check: int,
|
||||||
|
num_heads: int,
|
||||||
|
in_proj_weight: Tensor,
|
||||||
|
in_proj_bias: Tensor,
|
||||||
|
dropout_p: float,
|
||||||
|
out_proj_weight: Tensor,
|
||||||
|
out_proj_bias: Tensor,
|
||||||
|
training: bool = True,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
query, key, value: map a query and a set of key-value pairs to an output.
|
||||||
|
pos_emb: Positional embedding tensor
|
||||||
|
embed_dim_to_check: total dimension of the model.
|
||||||
|
num_heads: parallel attention heads.
|
||||||
|
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||||
|
dropout_p: probability of an element to be zeroed.
|
||||||
|
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||||
|
training: apply dropout if is ``True``.
|
||||||
|
key_padding_mask: if provided, specified padding elements in the key will
|
||||||
|
be ignored by the attention. This is an binary mask. When the value is True,
|
||||||
|
the corresponding value on the attention layer will be filled with -inf.
|
||||||
|
need_weights: output attn_output_weights.
|
||||||
|
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||||
|
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
Inputs:
|
||||||
|
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||||
|
the embedding dimension.
|
||||||
|
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
|
||||||
|
length, N is the batch size, E is the embedding dimension.
|
||||||
|
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||||
|
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
||||||
|
will be unchanged. If a BoolTensor is provided, the positions with the
|
||||||
|
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||||
|
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||||
|
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||||
|
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
||||||
|
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||||
|
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||||
|
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||||
|
is provided, it will be added to the attention weight.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||||
|
E is the embedding dimension.
|
||||||
|
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||||
|
L is the target sequence length, S is the source sequence length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
|
assert embed_dim == embed_dim_to_check
|
||||||
|
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||||
|
|
||||||
|
head_dim = embed_dim // num_heads
|
||||||
|
assert (
|
||||||
|
head_dim * num_heads == embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
|
||||||
|
scaling = float(head_dim) ** -0.5
|
||||||
|
|
||||||
|
if torch.equal(query, key) and torch.equal(key, value):
|
||||||
|
# self-attention
|
||||||
|
q, k, v = nn.functional.linear(
|
||||||
|
query, in_proj_weight, in_proj_bias
|
||||||
|
).chunk(3, dim=-1)
|
||||||
|
|
||||||
|
elif torch.equal(key, value):
|
||||||
|
# encoder-decoder attention
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = 0
|
||||||
|
_end = embed_dim
|
||||||
|
_w = in_proj_weight[_start:_end, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:_end]
|
||||||
|
q = nn.functional.linear(query, _w, _b)
|
||||||
|
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = embed_dim
|
||||||
|
_end = None
|
||||||
|
_w = in_proj_weight[_start:, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:]
|
||||||
|
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = 0
|
||||||
|
_end = embed_dim
|
||||||
|
_w = in_proj_weight[_start:_end, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:_end]
|
||||||
|
q = nn.functional.linear(query, _w, _b)
|
||||||
|
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = embed_dim
|
||||||
|
_end = embed_dim * 2
|
||||||
|
_w = in_proj_weight[_start:_end, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:_end]
|
||||||
|
k = nn.functional.linear(key, _w, _b)
|
||||||
|
|
||||||
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
|
_b = in_proj_bias
|
||||||
|
_start = embed_dim * 2
|
||||||
|
_end = None
|
||||||
|
_w = in_proj_weight[_start:, :]
|
||||||
|
if _b is not None:
|
||||||
|
_b = _b[_start:]
|
||||||
|
v = nn.functional.linear(value, _w, _b)
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
assert (
|
||||||
|
attn_mask.dtype == torch.float32
|
||||||
|
or attn_mask.dtype == torch.float64
|
||||||
|
or attn_mask.dtype == torch.float16
|
||||||
|
or attn_mask.dtype == torch.uint8
|
||||||
|
or attn_mask.dtype == torch.bool
|
||||||
|
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
||||||
|
attn_mask.dtype
|
||||||
|
)
|
||||||
|
if attn_mask.dtype == torch.uint8:
|
||||||
|
warnings.warn(
|
||||||
|
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
|
||||||
|
)
|
||||||
|
attn_mask = attn_mask.to(torch.bool)
|
||||||
|
|
||||||
|
if attn_mask.dim() == 2:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||||
|
raise RuntimeError(
|
||||||
|
"The size of the 2D attn_mask is not correct."
|
||||||
|
)
|
||||||
|
elif attn_mask.dim() == 3:
|
||||||
|
if list(attn_mask.size()) != [
|
||||||
|
bsz * num_heads,
|
||||||
|
query.size(0),
|
||||||
|
key.size(0),
|
||||||
|
]:
|
||||||
|
raise RuntimeError(
|
||||||
|
"The size of the 3D attn_mask is not correct."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"attn_mask's dimension {} is not supported".format(
|
||||||
|
attn_mask.dim()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# attn_mask's dim is 3 now.
|
||||||
|
|
||||||
|
# convert ByteTensor key_padding_mask to bool
|
||||||
|
if (
|
||||||
|
key_padding_mask is not None
|
||||||
|
and key_padding_mask.dtype == torch.uint8
|
||||||
|
):
|
||||||
|
warnings.warn(
|
||||||
|
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||||
|
)
|
||||||
|
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||||
|
|
||||||
|
q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
|
||||||
|
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
|
||||||
|
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||||
|
|
||||||
|
src_len = k.size(0)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
||||||
|
key_padding_mask.size(0), bsz
|
||||||
|
)
|
||||||
|
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
|
||||||
|
key_padding_mask.size(1), src_len
|
||||||
|
)
|
||||||
|
|
||||||
|
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
||||||
|
|
||||||
|
pos_emb_bsz = pos_emb.size(0)
|
||||||
|
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||||
|
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||||
|
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||||
|
|
||||||
|
q_with_bias_u = (q + self._pos_bias_u()).transpose(
|
||||||
|
1, 2
|
||||||
|
) # (batch, head, time1, d_k)
|
||||||
|
|
||||||
|
q_with_bias_v = (q + self._pos_bias_v()).transpose(
|
||||||
|
1, 2
|
||||||
|
) # (batch, head, time1, d_k)
|
||||||
|
|
||||||
|
# compute attention score
|
||||||
|
# first compute matrix a and matrix c
|
||||||
|
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||||
|
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||||
|
matrix_ac = torch.matmul(
|
||||||
|
q_with_bias_u, k
|
||||||
|
) # (batch, head, time1, time2)
|
||||||
|
|
||||||
|
# compute matrix b and matrix d
|
||||||
|
matrix_bd = torch.matmul(
|
||||||
|
q_with_bias_v, p.transpose(-2, -1)
|
||||||
|
) # (batch, head, time1, 2*time1-1)
|
||||||
|
matrix_bd = self.rel_shift(matrix_bd)
|
||||||
|
|
||||||
|
attn_output_weights = (
|
||||||
|
matrix_ac + matrix_bd
|
||||||
|
) # (batch, head, time1, time2)
|
||||||
|
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz * num_heads, tgt_len, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert list(attn_output_weights.size()) == [
|
||||||
|
bsz * num_heads,
|
||||||
|
tgt_len,
|
||||||
|
src_len,
|
||||||
|
]
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
if attn_mask.dtype == torch.bool:
|
||||||
|
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
|
||||||
|
else:
|
||||||
|
attn_output_weights += attn_mask
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz, num_heads, tgt_len, src_len
|
||||||
|
)
|
||||||
|
attn_output_weights = attn_output_weights.masked_fill(
|
||||||
|
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||||
|
float("-inf"),
|
||||||
|
)
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz * num_heads, tgt_len, src_len
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||||
|
attn_output_weights = nn.functional.dropout(
|
||||||
|
attn_output_weights, p=dropout_p, training=training
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = torch.bmm(attn_output_weights, v)
|
||||||
|
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||||
|
attn_output = (
|
||||||
|
attn_output.transpose(0, 1)
|
||||||
|
.contiguous()
|
||||||
|
.view(tgt_len, bsz, embed_dim)
|
||||||
|
)
|
||||||
|
attn_output = nn.functional.linear(
|
||||||
|
attn_output, out_proj_weight, out_proj_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
if need_weights:
|
||||||
|
# average attention weights over heads
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz, num_heads, tgt_len, src_len
|
||||||
|
)
|
||||||
|
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
||||||
|
else:
|
||||||
|
return attn_output, None
|
||||||
|
|
||||||
|
|
||||||
|
class ConvolutionModule(nn.Module):
|
||||||
|
"""ConvolutionModule in Conformer model.
|
||||||
|
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels (int): The number of channels of conv layers.
|
||||||
|
kernel_size (int): Kernerl size of conv layers.
|
||||||
|
bias (bool): Whether to use bias in conv layers (default=True).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, channels: int, kernel_size: int, bias: bool = True
|
||||||
|
) -> None:
|
||||||
|
"""Construct an ConvolutionModule object."""
|
||||||
|
super(ConvolutionModule, self).__init__()
|
||||||
|
# kernerl_size should be a odd number for 'SAME' padding
|
||||||
|
assert (kernel_size - 1) % 2 == 0
|
||||||
|
|
||||||
|
self.pointwise_conv1 = ScaledConv1d(
|
||||||
|
channels,
|
||||||
|
2 * channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
# after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
|
||||||
|
# For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
|
||||||
|
# but sometimes, for some reason, for layer 0 the rms ends up being very large,
|
||||||
|
# between 50 and 100 for different channels. This will cause very peaky and
|
||||||
|
# sparse derivatives for the sigmoid gating function, which will tend to make
|
||||||
|
# the loss function not learn effectively. (for most layers the average absolute values
|
||||||
|
# are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
|
||||||
|
# at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
|
||||||
|
# layers, which likely breaks down as 0.5 for the "linear" half and
|
||||||
|
# 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we
|
||||||
|
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
||||||
|
# it will be in a better position to start learning something, i.e. to latch onto
|
||||||
|
# the correct range.
|
||||||
|
self.deriv_balancer1 = ActivationBalancer(
|
||||||
|
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
self.depthwise_conv = ScaledConv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=(kernel_size - 1) // 2,
|
||||||
|
groups=channels,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.deriv_balancer2 = ActivationBalancer(
|
||||||
|
channel_dim=1, min_positive=0.05, max_positive=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
self.activation = DoubleSwish()
|
||||||
|
|
||||||
|
self.pointwise_conv2 = ScaledConv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=bias,
|
||||||
|
initial_scale=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""Compute convolution module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor (#time, batch, channels).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor (#time, batch, channels).
|
||||||
|
|
||||||
|
"""
|
||||||
|
# exchange the temporal dimension and the feature dimension
|
||||||
|
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||||
|
|
||||||
|
# GLU mechanism
|
||||||
|
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
|
||||||
|
|
||||||
|
x = self.deriv_balancer1(x)
|
||||||
|
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||||
|
|
||||||
|
# 1D Depthwise Conv
|
||||||
|
x = self.depthwise_conv(x)
|
||||||
|
|
||||||
|
x = self.deriv_balancer2(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
|
||||||
|
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||||
|
|
||||||
|
return x.permute(2, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
feature_dim = 50
|
||||||
|
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||||
|
batch_size = 5
|
||||||
|
seq_len = 20
|
||||||
|
# Just make sure the forward pass runs.
|
||||||
|
f = c(
|
||||||
|
torch.randn(batch_size, seq_len, feature_dim),
|
||||||
|
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||||
|
warmup=0.5,
|
||||||
|
)
|
996
egs/librispeech/ASR/conformer_ctc2/decode.py
Executable file
996
egs/librispeech/ASR/conformer_ctc2/decode.py
Executable file
@ -0,0 +1,996 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
|
||||||
|
# Fangjun Kuang,
|
||||||
|
# Quandong Wang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from conformer import Conformer
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
nbest_decoding,
|
||||||
|
nbest_oracle,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_attention_decoder,
|
||||||
|
rescore_with_n_best_list,
|
||||||
|
rescore_with_rnn_lm,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
|
from icefall.env import get_env_info
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.rnn_lm.model import RnnLmModel
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
get_texts,
|
||||||
|
load_averaged_model,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=77,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="attention-decoder",
|
||||||
|
help="""Decoding method.
|
||||||
|
Supported values are:
|
||||||
|
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||||
|
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||||
|
It needs neither a lexicon nor an n-gram LM.
|
||||||
|
- (1) ctc-greedy-search. It only use CTC output and a sentence piece
|
||||||
|
model for decoding. It produces the same results with ctc-decoding.
|
||||||
|
- (2) 1best. Extract the best path from the decoding lattice as the
|
||||||
|
decoding result.
|
||||||
|
- (3) nbest. Extract n paths from the decoding lattice; the path
|
||||||
|
with the highest score is the decoding result.
|
||||||
|
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||||
|
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||||
|
the highest score is the decoding result.
|
||||||
|
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||||
|
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||||
|
is the decoding result.
|
||||||
|
- (6) attention-decoder. Extract n paths from the LM rescored
|
||||||
|
lattice, the path with the highest score is the decoding result.
|
||||||
|
- (7) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
|
||||||
|
you have trained an RNN LM using ./rnn_lm/train.py
|
||||||
|
- (8) nbest-oracle. Its WER is the lower bound of any n-best
|
||||||
|
rescoring method can achieve. Useful for debugging n-best
|
||||||
|
rescoring method.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-decoder-layers",
|
||||||
|
type=int,
|
||||||
|
default=6,
|
||||||
|
help="""Number of decoder layer of transformer decoder.
|
||||||
|
Setting this to 0 will not create the decoder at all (pure CTC model)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="""Number of paths for n-best based decoding method.
|
||||||
|
Used only when "method" is one of the following values:
|
||||||
|
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""The scale to be applied to `lattice.scores`.
|
||||||
|
It's needed if you use any kinds of n-best based rescoring.
|
||||||
|
Used only when "method" is one of the following values:
|
||||||
|
nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
|
||||||
|
A smaller value results in more unique paths.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="conformer_ctc2/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lm",
|
||||||
|
help="""The n-gram LM dir.
|
||||||
|
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="rnn_lm/exp",
|
||||||
|
help="""Used only when --method is rnn-lm.
|
||||||
|
It specifies the path to RNN LM exp dir.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-epoch",
|
||||||
|
type=int,
|
||||||
|
default=7,
|
||||||
|
help="""Used only when --method is rnn-lm.
|
||||||
|
It specifies the checkpoint to use.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-avg",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="""Used only when --method is rnn-lm.
|
||||||
|
It specifies the number of checkpoints to average.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-embedding-dim",
|
||||||
|
type=int,
|
||||||
|
default=2048,
|
||||||
|
help="Embedding dim of the model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-hidden-dim",
|
||||||
|
type=int,
|
||||||
|
default=2048,
|
||||||
|
help="Hidden dim of the model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-num-layers",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of RNN layers the model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnn-lm-tie-weights",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""True to share the weights between the input embedding layer and the
|
||||||
|
last output linear layer
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
# parameters for conformer
|
||||||
|
"subsampling_factor": 4,
|
||||||
|
"feature_dim": 80,
|
||||||
|
"nhead": 8,
|
||||||
|
"dim_feedforward": 2048,
|
||||||
|
"encoder_dim": 512,
|
||||||
|
"num_encoder_layers": 12,
|
||||||
|
# parameters for decoding
|
||||||
|
"search_beam": 20,
|
||||||
|
"output_beam": 8,
|
||||||
|
"min_active_states": 30,
|
||||||
|
"max_active_states": 10000,
|
||||||
|
"use_double_scores": True,
|
||||||
|
"env_info": get_env_info(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def ctc_greedy_search(
|
||||||
|
nnet_output: torch.Tensor,
|
||||||
|
memory: torch.Tensor,
|
||||||
|
memory_key_padding_mask: torch.Tensor,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Apply CTC greedy search
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speech (torch.Tensor): (batch, max_len, feat_dim)
|
||||||
|
speech_length (torch.Tensor): (batch, )
|
||||||
|
Returns:
|
||||||
|
List[List[int]]: best path result
|
||||||
|
"""
|
||||||
|
batch_size = memory.shape[1]
|
||||||
|
# Let's assume B = batch_size
|
||||||
|
encoder_out = memory
|
||||||
|
encoder_mask = memory_key_padding_mask
|
||||||
|
maxlen = encoder_out.size(0)
|
||||||
|
|
||||||
|
ctc_probs = nnet_output # (B, maxlen, vocab_size)
|
||||||
|
topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
|
||||||
|
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
|
||||||
|
topk_index = topk_index.masked_fill_(encoder_mask, 0) # (B, maxlen)
|
||||||
|
hyps = [hyp.tolist() for hyp in topk_index]
|
||||||
|
scores = topk_prob.max(1)
|
||||||
|
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
|
||||||
|
return hyps, scores
|
||||||
|
|
||||||
|
|
||||||
|
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
||||||
|
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
|
||||||
|
new_hyp: List[int] = []
|
||||||
|
cur = 0
|
||||||
|
while cur < len(hyp):
|
||||||
|
if hyp[cur] != 0:
|
||||||
|
new_hyp.append(hyp[cur])
|
||||||
|
prev = cur
|
||||||
|
while cur < len(hyp) and hyp[cur] == hyp[prev]:
|
||||||
|
cur += 1
|
||||||
|
return new_hyp
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
rnn_lm_model: Optional[nn.Module],
|
||||||
|
HLG: Optional[k2.Fsa],
|
||||||
|
H: Optional[k2.Fsa],
|
||||||
|
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||||
|
batch: dict,
|
||||||
|
word_table: k2.SymbolTable,
|
||||||
|
sos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
G: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[List[str]]]:
|
||||||
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
|
following format:
|
||||||
|
|
||||||
|
- key: It indicates the setting used for decoding. For example,
|
||||||
|
if no rescoring is used, the key is the string `no_rescore`.
|
||||||
|
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
||||||
|
where `xxx` is the value of `lm_scale`. An example key is
|
||||||
|
`lm_scale_0.7`
|
||||||
|
- value: It contains the decoding result. `len(value)` equals to
|
||||||
|
batch size. `value[i]` is the decoding result for the i-th
|
||||||
|
utterance in the given batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
|
||||||
|
- params.method is "1best", it uses 1best decoding without LM rescoring.
|
||||||
|
- params.method is "nbest", it uses nbest decoding without LM rescoring.
|
||||||
|
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
|
||||||
|
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
|
||||||
|
rescoring.
|
||||||
|
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
rnn_lm_model:
|
||||||
|
The neural model for RNN LM.
|
||||||
|
HLG:
|
||||||
|
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||||
|
H:
|
||||||
|
The ctc topo. Used only when params.method is ctc-decoding.
|
||||||
|
bpe_model:
|
||||||
|
The BPE model. Used only when params.method is ctc-decoding.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
sos_id:
|
||||||
|
The token ID of the SOS.
|
||||||
|
eos_id:
|
||||||
|
The token ID of the EOS.
|
||||||
|
G:
|
||||||
|
An LM. It is not None when params.method is "nbest-rescoring"
|
||||||
|
or "whole-lattice-rescoring". In general, the G in HLG
|
||||||
|
is a 3-gram LM, while this G is a 4-gram LM.
|
||||||
|
Returns:
|
||||||
|
Return the decoding result. See above description for the format of
|
||||||
|
the returned dict. Note: If it decodes to nothing, then return None.
|
||||||
|
"""
|
||||||
|
if HLG is not None:
|
||||||
|
device = HLG.device
|
||||||
|
else:
|
||||||
|
device = H.device
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
|
||||||
|
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
|
||||||
|
# nnet_output is (N, T, C)
|
||||||
|
|
||||||
|
supervision_segments = torch.stack(
|
||||||
|
(
|
||||||
|
supervisions["sequence_idx"],
|
||||||
|
torch.div(
|
||||||
|
supervisions["start_frame"],
|
||||||
|
params.subsampling_factor,
|
||||||
|
rounding_mode="trunc",
|
||||||
|
),
|
||||||
|
torch.div(
|
||||||
|
supervisions["num_frames"],
|
||||||
|
params.subsampling_factor,
|
||||||
|
rounding_mode="trunc",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
).to(torch.int32)
|
||||||
|
|
||||||
|
if H is None:
|
||||||
|
assert HLG is not None
|
||||||
|
decoding_graph = HLG
|
||||||
|
else:
|
||||||
|
assert HLG is None
|
||||||
|
assert bpe_model is not None
|
||||||
|
decoding_graph = H
|
||||||
|
|
||||||
|
lattice = get_lattice(
|
||||||
|
nnet_output=nnet_output,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
search_beam=params.search_beam,
|
||||||
|
output_beam=params.output_beam,
|
||||||
|
min_active_states=params.min_active_states,
|
||||||
|
max_active_states=params.max_active_states,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.method == "ctc-decoding":
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
# Note: `best_path.aux_labels` contains token IDs, not word IDs
|
||||||
|
# since we are using H, not HLG here.
|
||||||
|
#
|
||||||
|
# token_ids is a lit-of-list of IDs
|
||||||
|
token_ids = get_texts(best_path)
|
||||||
|
|
||||||
|
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||||
|
hyps = bpe_model.decode(token_ids)
|
||||||
|
|
||||||
|
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||||
|
hyps = [s.split() for s in hyps]
|
||||||
|
key = "ctc-decoding"
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
|
if params.method == "ctc-greedy-search":
|
||||||
|
hyps, _ = ctc_greedy_search(
|
||||||
|
nnet_output,
|
||||||
|
memory,
|
||||||
|
memory_key_padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||||
|
hyps = bpe_model.decode(hyps)
|
||||||
|
|
||||||
|
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||||
|
hyps = [s.split() for s in hyps]
|
||||||
|
key = "ctc-greedy-search"
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
|
if params.method == "nbest-oracle":
|
||||||
|
# Note: You can also pass rescored lattices to it.
|
||||||
|
# We choose the HLG decoded lattice for speed reasons
|
||||||
|
# as HLG decoding is faster and the oracle WER
|
||||||
|
# is only slightly worse than that of rescored lattices.
|
||||||
|
best_path = nbest_oracle(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=supervisions["text"],
|
||||||
|
word_table=word_table,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
oov="<UNK>",
|
||||||
|
)
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||||
|
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
|
if params.method in ["1best", "nbest"]:
|
||||||
|
if params.method == "1best":
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
key = "no_rescore"
|
||||||
|
else:
|
||||||
|
best_path = nbest_decoding(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
use_double_scores=params.use_double_scores,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
|
assert params.method in [
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
"attention-decoder",
|
||||||
|
"rnn-lm",
|
||||||
|
]
|
||||||
|
|
||||||
|
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||||
|
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
||||||
|
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
||||||
|
|
||||||
|
if params.method == "nbest-rescoring":
|
||||||
|
best_path_dict = rescore_with_n_best_list(
|
||||||
|
lattice=lattice,
|
||||||
|
G=G,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
lm_scale_list=lm_scale_list,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
elif params.method == "whole-lattice-rescoring":
|
||||||
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=lm_scale_list,
|
||||||
|
)
|
||||||
|
elif params.method == "attention-decoder":
|
||||||
|
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||||
|
rescored_lattice = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=None,
|
||||||
|
)
|
||||||
|
# TODO: pass `lattice` instead of `rescored_lattice` to
|
||||||
|
# `rescore_with_attention_decoder`
|
||||||
|
|
||||||
|
best_path_dict = rescore_with_attention_decoder(
|
||||||
|
lattice=rescored_lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
model=model,
|
||||||
|
memory=memory,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
elif params.method == "rnn-lm":
|
||||||
|
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||||
|
rescored_lattice = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
best_path_dict = rescore_with_rnn_lm(
|
||||||
|
lattice=rescored_lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
rnn_lm_model=rnn_lm_model,
|
||||||
|
model=model,
|
||||||
|
memory=memory,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
blank_id=0,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert False, f"Unsupported decoding method: {params.method}"
|
||||||
|
|
||||||
|
ans = dict()
|
||||||
|
if best_path_dict is not None:
|
||||||
|
for lm_scale_str, best_path in best_path_dict.items():
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||||
|
ans[lm_scale_str] = hyps
|
||||||
|
else:
|
||||||
|
ans = None
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
rnn_lm_model: Optional[nn.Module],
|
||||||
|
HLG: Optional[k2.Fsa],
|
||||||
|
H: Optional[k2.Fsa],
|
||||||
|
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||||
|
word_table: k2.SymbolTable,
|
||||||
|
sos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
G: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
rnn_lm_model:
|
||||||
|
The neural model for RNN LM.
|
||||||
|
HLG:
|
||||||
|
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||||
|
H:
|
||||||
|
The ctc topo. Used only when params.method is ctc-decoding.
|
||||||
|
bpe_model:
|
||||||
|
The BPE model. Used only when params.method is ctc-decoding.
|
||||||
|
word_table:
|
||||||
|
It is the word symbol table.
|
||||||
|
sos_id:
|
||||||
|
The token ID for SOS.
|
||||||
|
eos_id:
|
||||||
|
The token ID for EOS.
|
||||||
|
G:
|
||||||
|
An LM. It is not None when params.method is "nbest-rescoring"
|
||||||
|
or "whole-lattice-rescoring". In general, the G in HLG
|
||||||
|
is a 3-gram LM, while this G is a 4-gram LM.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "no-rescore" if no LM rescoring
|
||||||
|
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
results = defaultdict(list)
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
|
||||||
|
hyps_dict = decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
rnn_lm_model=rnn_lm_model,
|
||||||
|
HLG=HLG,
|
||||||
|
H=H,
|
||||||
|
bpe_model=bpe_model,
|
||||||
|
batch=batch,
|
||||||
|
word_table=word_table,
|
||||||
|
G=G,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if hyps_dict is not None:
|
||||||
|
for lm_scale, hyps in hyps_dict.items():
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for hyp_words, ref_text in zip(hyps, texts):
|
||||||
|
ref_words = ref_text.split()
|
||||||
|
this_batch.append((ref_words, hyp_words))
|
||||||
|
|
||||||
|
results[lm_scale].extend(this_batch)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
len(results) > 0
|
||||||
|
), "It should not decode to empty in the first batch!"
|
||||||
|
this_batch = []
|
||||||
|
hyp_words = []
|
||||||
|
for ref_text in texts:
|
||||||
|
ref_words = ref_text.split()
|
||||||
|
this_batch.append((ref_words, hyp_words))
|
||||||
|
|
||||||
|
for lm_scale in results.keys():
|
||||||
|
results[lm_scale].extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % 100 == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||||
|
):
|
||||||
|
if params.method in ("attention-decoder", "rnn-lm"):
|
||||||
|
# Set it to False since there are too many logs.
|
||||||
|
enable_log = False
|
||||||
|
else:
|
||||||
|
enable_log = True
|
||||||
|
test_set_wers = dict()
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
if enable_log:
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results, enable_log=enable_log
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
if enable_log:
|
||||||
|
logging.info(
|
||||||
|
"Wrote detailed error stats to {}".format(errs_filename)
|
||||||
|
)
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
args.lang_dir = Path(args.lang_dir)
|
||||||
|
args.lm_dir = Path(args.lm_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
num_classes = max_token_id + 1 # +1 for the blank
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||||
|
params.lang_dir,
|
||||||
|
device=device,
|
||||||
|
sos_token="<sos/eos>",
|
||||||
|
eos_token="<sos/eos>",
|
||||||
|
)
|
||||||
|
sos_id = graph_compiler.sos_id
|
||||||
|
eos_id = graph_compiler.eos_id
|
||||||
|
|
||||||
|
params.num_classes = num_classes
|
||||||
|
params.sos_id = sos_id
|
||||||
|
params.eos_id = eos_id
|
||||||
|
|
||||||
|
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
|
||||||
|
HLG = None
|
||||||
|
H = k2.ctc_topo(
|
||||||
|
max_token=max_token_id,
|
||||||
|
modified=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
bpe_model = spm.SentencePieceProcessor()
|
||||||
|
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||||
|
else:
|
||||||
|
H = None
|
||||||
|
bpe_model = None
|
||||||
|
HLG = k2.Fsa.from_dict(
|
||||||
|
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||||
|
)
|
||||||
|
assert HLG.requires_grad is False
|
||||||
|
|
||||||
|
if not hasattr(HLG, "lm_scores"):
|
||||||
|
HLG.lm_scores = HLG.scores.clone()
|
||||||
|
|
||||||
|
if params.method in (
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
"attention-decoder",
|
||||||
|
"rnn-lm",
|
||||||
|
):
|
||||||
|
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||||
|
logging.info("Loading G_4_gram.fst.txt")
|
||||||
|
logging.warning("It may take 8 minutes.")
|
||||||
|
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
|
||||||
|
first_word_disambig_id = lexicon.word_table["#0"]
|
||||||
|
|
||||||
|
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||||
|
# G.aux_labels is not needed in later computations, so
|
||||||
|
# remove it here.
|
||||||
|
del G.aux_labels
|
||||||
|
# CAUTION: The following line is crucial.
|
||||||
|
# Arcs entering the back-off state have label equal to #0.
|
||||||
|
# We have to change it to 0 here.
|
||||||
|
G.labels[G.labels >= first_word_disambig_id] = 0
|
||||||
|
# See https://github.com/k2-fsa/k2/issues/874
|
||||||
|
# for why we need to set G.properties to None
|
||||||
|
G.__dict__["_properties"] = None
|
||||||
|
G = k2.Fsa.from_fsas([G]).to(device)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
# Save a dummy value so that it can be loaded in C++.
|
||||||
|
# See https://github.com/pytorch/pytorch/issues/67902
|
||||||
|
# for why we need to do this.
|
||||||
|
G.dummy = 1
|
||||||
|
|
||||||
|
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||||
|
else:
|
||||||
|
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||||
|
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||||
|
G = k2.Fsa.from_dict(d)
|
||||||
|
|
||||||
|
if params.method in [
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
"attention-decoder",
|
||||||
|
"rnn-lm",
|
||||||
|
]:
|
||||||
|
# Add epsilon self-loops to G as we will compose
|
||||||
|
# it with the whole lattice later
|
||||||
|
G = k2.add_epsilon_self_loops(G)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
G = G.to(device)
|
||||||
|
|
||||||
|
# G.lm_scores is used to replace HLG.lm_scores during
|
||||||
|
# LM rescoring.
|
||||||
|
G.lm_scores = G.scores.clone()
|
||||||
|
else:
|
||||||
|
G = None
|
||||||
|
|
||||||
|
model = Conformer(
|
||||||
|
num_features=params.feature_dim,
|
||||||
|
nhead=params.nhead,
|
||||||
|
d_model=params.encoder_dim,
|
||||||
|
num_classes=num_classes,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(
|
||||||
|
params.exp_dir, iteration=-params.iter
|
||||||
|
)[: params.avg]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(
|
||||||
|
params.exp_dir, iteration=-params.iter
|
||||||
|
)[: params.avg + 1]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
rnn_lm_model = None
|
||||||
|
if params.method == "rnn-lm":
|
||||||
|
rnn_lm_model = RnnLmModel(
|
||||||
|
vocab_size=params.num_classes,
|
||||||
|
embedding_dim=params.rnn_lm_embedding_dim,
|
||||||
|
hidden_dim=params.rnn_lm_hidden_dim,
|
||||||
|
num_layers=params.rnn_lm_num_layers,
|
||||||
|
tie_weights=params.rnn_lm_tie_weights,
|
||||||
|
)
|
||||||
|
if params.rnn_lm_avg == 1:
|
||||||
|
load_checkpoint(
|
||||||
|
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
|
||||||
|
rnn_lm_model,
|
||||||
|
)
|
||||||
|
rnn_lm_model.to(device)
|
||||||
|
else:
|
||||||
|
rnn_lm_model = load_averaged_model(
|
||||||
|
params.rnn_lm_exp_dir,
|
||||||
|
rnn_lm_model,
|
||||||
|
params.rnn_lm_epoch,
|
||||||
|
params.rnn_lm_avg,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
rnn_lm_model.eval()
|
||||||
|
|
||||||
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
|
|
||||||
|
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
|
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
|
||||||
|
test_sets = ["test-clean", "test-other"]
|
||||||
|
test_dl = [test_clean_dl, test_other_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
rnn_lm_model=rnn_lm_model,
|
||||||
|
HLG=HLG,
|
||||||
|
H=H,
|
||||||
|
bpe_model=bpe_model,
|
||||||
|
word_table=lexicon.word_table,
|
||||||
|
G=G,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_results(
|
||||||
|
params=params, test_set_name=test_set, results_dict=results_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
281
egs/librispeech/ASR/conformer_ctc2/export.py
Executable file
281
egs/librispeech/ASR/conformer_ctc2/export.py
Executable file
@ -0,0 +1,281 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Quandong Wang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# This script converts several saved checkpoints
|
||||||
|
# to a single one using model averaging.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
./conformer_ctc2/export.py \
|
||||||
|
--exp-dir ./conformer_ctc2/exp \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10
|
||||||
|
|
||||||
|
It will generate a file exp_dir/pretrained.pt
|
||||||
|
|
||||||
|
To use the generated file with `conformer_ctc2/decode.py`,
|
||||||
|
you can do:
|
||||||
|
|
||||||
|
cd /path/to/exp_dir
|
||||||
|
ln -s pretrained.pt epoch-9999.pt
|
||||||
|
|
||||||
|
cd /path/to/egs/librispeech/ASR
|
||||||
|
./conformer_ctc2/decode.py \
|
||||||
|
--exp-dir ./conformer_ctc2/exp \
|
||||||
|
--epoch 9999 \
|
||||||
|
--avg 1 \
|
||||||
|
--max-duration 100
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from decode import get_params
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from conformer import Conformer
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-decoder-layers",
|
||||||
|
type=int,
|
||||||
|
default=6,
|
||||||
|
help="""Number of decoder layer of transformer decoder.
|
||||||
|
Setting this to 0 will not create the decoder at all (pure CTC model)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="conformer_ctc2/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="""True to save a model after applying torch.jit.script.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
args.lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
num_classes = max_token_id + 1 # +1 for the blank
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
|
||||||
|
model = Conformer(
|
||||||
|
num_features=params.feature_dim,
|
||||||
|
nhead=params.nhead,
|
||||||
|
d_model=params.encoder_dim,
|
||||||
|
num_classes=num_classes,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(
|
||||||
|
params.exp_dir, iteration=-params.iter
|
||||||
|
)[: params.avg]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(
|
||||||
|
params.exp_dir, iteration=-params.iter
|
||||||
|
)[: params.avg + 1]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if params.jit:
|
||||||
|
logging.info("Using torch.jit.script")
|
||||||
|
model = torch.jit.script(model)
|
||||||
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
model.save(str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
else:
|
||||||
|
logging.info("Not using torch.jit.script")
|
||||||
|
# Save it using a format so that it can be loaded
|
||||||
|
# by :func:`load_checkpoint`
|
||||||
|
filename = params.exp_dir / "pretrained.pt"
|
||||||
|
torch.save({"model": model.state_dict()}, str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = (
|
||||||
|
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
1
egs/librispeech/ASR/conformer_ctc2/label_smoothing.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/label_smoothing.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../conformer_ctc/label_smoothing.py
|
1
egs/librispeech/ASR/conformer_ctc2/optim.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/optim.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/optim.py
|
1
egs/librispeech/ASR/conformer_ctc2/scaling.py
Symbolic link
1
egs/librispeech/ASR/conformer_ctc2/scaling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless2/scaling.py
|
121
egs/librispeech/ASR/conformer_ctc2/subsampling.py
Normal file
121
egs/librispeech/ASR/conformer_ctc2/subsampling.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||||
|
# 2022 Xiaomi Corporation (author: Quandong Wang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling import (
|
||||||
|
ActivationBalancer,
|
||||||
|
BasicNorm,
|
||||||
|
DoubleSwish,
|
||||||
|
ScaledConv2d,
|
||||||
|
ScaledLinear,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSubsampling(nn.Module):
|
||||||
|
"""Convolutional 2D subsampling (to 1/4 length).
|
||||||
|
|
||||||
|
Convert an input of shape (N, T, idim) to an output
|
||||||
|
with shape (N, T', odim), where
|
||||||
|
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
||||||
|
|
||||||
|
It is based on
|
||||||
|
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
layer1_channels: int = 8,
|
||||||
|
layer2_channels: int = 32,
|
||||||
|
layer3_channels: int = 128,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
in_channels:
|
||||||
|
Number of channels in. The input shape is (N, T, in_channels).
|
||||||
|
Caution: It requires: T >=7, in_channels >=7
|
||||||
|
out_channels
|
||||||
|
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
|
||||||
|
layer1_channels:
|
||||||
|
Number of channels in layer1
|
||||||
|
layer1_channels:
|
||||||
|
Number of channels in layer2
|
||||||
|
"""
|
||||||
|
assert in_channels >= 7
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
ScaledConv2d(
|
||||||
|
in_channels=1,
|
||||||
|
out_channels=layer1_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
),
|
||||||
|
ActivationBalancer(channel_dim=1),
|
||||||
|
DoubleSwish(),
|
||||||
|
ScaledConv2d(
|
||||||
|
in_channels=layer1_channels,
|
||||||
|
out_channels=layer2_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
),
|
||||||
|
ActivationBalancer(channel_dim=1),
|
||||||
|
DoubleSwish(),
|
||||||
|
ScaledConv2d(
|
||||||
|
in_channels=layer2_channels,
|
||||||
|
out_channels=layer3_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
),
|
||||||
|
ActivationBalancer(channel_dim=1),
|
||||||
|
DoubleSwish(),
|
||||||
|
)
|
||||||
|
self.out = ScaledLinear(
|
||||||
|
layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
|
||||||
|
)
|
||||||
|
# set learn_eps=False because out_norm is preceded by `out`, and `out`
|
||||||
|
# itself has learned scale, so the extra degree of freedom is not
|
||||||
|
# needed.
|
||||||
|
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||||
|
# constrain median of output to be close to zero.
|
||||||
|
self.out_balancer = ActivationBalancer(
|
||||||
|
channel_dim=-1, min_positive=0.45, max_positive=0.55
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Subsample x.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
Its shape is (N, T, idim).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||||
|
"""
|
||||||
|
# On entry, x is (N, T, idim)
|
||||||
|
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||||
|
x = self.conv(x)
|
||||||
|
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||||
|
b, c, t, f = x.size()
|
||||||
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
|
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||||
|
x = self.out_norm(x)
|
||||||
|
x = self.out_balancer(x)
|
||||||
|
return x
|
1119
egs/librispeech/ASR/conformer_ctc2/train.py
Executable file
1119
egs/librispeech/ASR/conformer_ctc2/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1092
egs/librispeech/ASR/conformer_ctc2/transformer.py
Normal file
1092
egs/librispeech/ASR/conformer_ctc2/transformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -643,7 +643,8 @@ class ScaledEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
s = "{num_embeddings}, {embedding_dim}, scale={scale}"
|
# s = "{num_embeddings}, {embedding_dim}, scale={scale}"
|
||||||
|
s = "{num_embeddings}, {embedding_dim}"
|
||||||
if self.padding_idx is not None:
|
if self.padding_idx is not None:
|
||||||
s += ", padding_idx={padding_idx}"
|
s += ", padding_idx={padding_idx}"
|
||||||
if self.scale_grad_by_freq is not False:
|
if self.scale_grad_by_freq is not False:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user