RNN-T Conformer training for LibriSpeech (#143)

* Begin to add RNN-T training for librispeech.

* Copy files from conformer_ctc.

Will edit it.

* Use conformer/transformer model as encoder.

* Begin to add training script.

* Add training code.

* Remove long utterances to avoid OOM when a large max_duraiton is used.

* Begin to add decoding script.

* Add decoding script.

* Minor fixes.

* Add beam search.

* Use LSTM layers for the encoder.

Need more tunings.

* Use stateless decoder.

* Minor fixes to make it ready for merge.

* Fix README.

* Update RESULT.md to include RNN-T Conformer.

* Minor fixes.

* Fix tests.

* Minor fixes.

* Minor fixes.

* Fix tests.
This commit is contained in:
Fangjun Kuang 2021-12-18 07:42:51 +08:00 committed by GitHub
parent 76a51bf037
commit 1d44da845b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 8964 additions and 11 deletions

View File

@ -103,8 +103,10 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
cd ..
pytest -v -s ./transducer
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer
pytest -v -s
fi
- name: Run tests
if: startsWith(matrix.os, 'macos')
@ -120,5 +122,7 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
cd ..
pytest -v -s ./transducer
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer
pytest -v -s
fi

1
.gitignore vendored
View File

@ -8,3 +8,4 @@ exp*/
download
*.bak
*-bak
*bak.py

View File

@ -34,8 +34,11 @@ We do provide a Colab notebook for this recipe.
### LibriSpeech
We provide two models for this recipe: [conformer CTC model][LibriSpeech_conformer_ctc]
and [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc].
We provide 3 models for this recipe:
- [conformer CTC model][LibriSpeech_conformer_ctc]
- [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc]
- [RNN-T Conformer model][LibriSpeech_transducer]
#### Conformer CTC Model
@ -58,6 +61,20 @@ The WER for this model is:
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
#### RNN-T Conformer model
Using Conformer as encoder.
The best WER with greedy search is:
| | test-clean | test-other |
|-----|------------|------------|
| WER | 3.16 | 7.71 |
We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
### Aishell
We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]
@ -125,6 +142,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
[LibriSpeech_transducer]: egs/librispeech/ASR/transducer
[Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc
[Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc
[TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc

View File

@ -1,5 +1,51 @@
## Results
### LibriSpeech BPE training results (RNN-T)
#### 2021-12-17
RNN-T + Conformer encoder
The best WER is
| | test-clean | test-other |
|-----|------------|------------|
| WER | 3.16 | 7.71 |
using `--epoch 26 --avg 12` during decoding with greedy search.
The training command to reproduce the above WER is:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer/exp-lr-2.5-full \
--full-libri 1 \
--max-duration 250 \
--lr-factor 2.5
```
The decoding command is:
```
epoch=26
avg=12
./transducer/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir transducer/exp-lr-2.5-full \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100
```
You can find the tensorboard log at: <https://tensorboard.dev/experiment/PYIbeD6zRJez1ViXaRqqeg/>
### LibriSpeech BPE training results (Conformer-CTC)
#### 2021-11-09

View File

@ -428,8 +428,6 @@ def decode_dataset(
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:

View File

@ -0,0 +1,215 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()` in transducer/train.py
for usage.
"""
from lhotse import load_manifest
def main():
path = "./data/fbank/cuts_train-clean-100.json.gz"
path = "./data/fbank/cuts_train-clean-360.json.gz"
path = "./data/fbank/cuts_train-other-500.json.gz"
path = "./data/fbank/cuts_dev-clean.json.gz"
path = "./data/fbank/cuts_dev-other.json.gz"
path = "./data/fbank/cuts_test-clean.json.gz"
path = "./data/fbank/cuts_test-other.json.gz"
cuts = load_manifest(path)
cuts.describe()
if __name__ == "__main__":
main()
"""
## train-clean-100
Cuts count: 85617
Total duration (hours): 303.8
Speech duration (hours): 303.8 (100.0%)
***
Duration statistics (seconds):
mean 12.8
std 3.8
min 1.3
0.1% 1.9
0.5% 2.2
1% 2.5
5% 4.2
10% 6.4
25% 11.4
50% 13.8
75% 15.3
90% 16.7
95% 17.3
99% 18.1
99.5% 18.4
99.9% 18.8
max 27.2
## train-clean-360
Cuts count: 312042
Total duration (hours): 1098.2
Speech duration (hours): 1098.2 (100.0%)
***
Duration statistics (seconds):
mean 12.7
std 3.8
min 1.0
0.1% 1.8
0.5% 2.2
1% 2.5
5% 4.2
10% 6.2
25% 11.2
50% 13.7
75% 15.3
90% 16.6
95% 17.3
99% 18.1
99.5% 18.4
99.9% 18.8
max 33.0
## train-other 500
Cuts count: 446064
Total duration (hours): 1500.6
Speech duration (hours): 1500.6 (100.0%)
***
Duration statistics (seconds):
mean 12.1
std 4.2
min 0.8
0.1% 1.7
0.5% 2.1
1% 2.3
5% 3.5
10% 5.0
25% 9.8
50% 13.4
75% 15.1
90% 16.5
95% 17.2
99% 18.1
99.5% 18.4
99.9% 18.9
max 31.0
## dev-clean
Cuts count: 2703
Total duration (hours): 5.4
Speech duration (hours): 5.4 (100.0%)
***
Duration statistics (seconds):
mean 7.2
std 4.7
min 1.4
0.1% 1.6
0.5% 1.8
1% 1.9
5% 2.4
10% 2.7
25% 3.8
50% 5.9
75% 9.3
90% 13.3
95% 16.4
99% 23.8
99.5% 28.5
99.9% 32.3
max 32.6
## dev-other
Cuts count: 2864
Total duration (hours): 5.1
Speech duration (hours): 5.1 (100.0%)
***
Duration statistics (seconds):
mean 6.4
std 4.3
min 1.1
0.1% 1.3
0.5% 1.7
1% 1.8
5% 2.2
10% 2.6
25% 3.5
50% 5.3
75% 7.9
90% 12.0
95% 15.0
99% 22.2
99.5% 27.1
99.9% 32.4
max 35.2
## test-clean
Cuts count: 2620
Total duration (hours): 5.4
Speech duration (hours): 5.4 (100.0%)
***
Duration statistics (seconds):
mean 7.4
std 5.2
min 1.3
0.1% 1.6
0.5% 1.8
1% 2.0
5% 2.3
10% 2.7
25% 3.7
50% 5.8
75% 9.6
90% 14.6
95% 17.8
99% 25.5
99.5% 28.4
99.9% 32.8
max 35.0
## test-other
Cuts count: 2939
Total duration (hours): 5.3
Speech duration (hours): 5.3 (100.0%)
***
Duration statistics (seconds):
mean 6.5
std 4.4
min 1.2
0.1% 1.5
0.5% 1.8
1% 1.9
5% 2.3
10% 2.6
25% 3.4
50% 5.2
75% 8.2
90% 12.6
95% 15.8
99% 21.4
99.5% 23.8
99.9% 33.5
max 34.5
"""

View File

@ -0,0 +1,19 @@
## Introduction
The encoder consists of Conformer layers in this folder. You can use the
following command to start the training:
```bash
cd egs/librispeech/ASR
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer/exp \
--full-libri 1 \
--max-duration 250 \
--lr-factor 2.5
```

View File

@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py

View File

@ -0,0 +1,212 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
hyp = []
max_u = 1000 # terminate after this number of steps
u = 0
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out)
# logits is (1, 1, 1, vocab_size)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
# TODO: Use logits.argmax()
y = log_prob.argmax()
if y != blank_id:
hyp.append(y.item())
y = y.reshape(1, 1)
decoder_out, (h, c) = model.decoder(y, (h, c))
u += 1
else:
t += 1
return hyp
@dataclass
class Hypothesis:
ys: List[int] # the predicted sequences so far
log_prob: float # The log prob of ys
# Optional decoder state. We assume it is LSTM for now,
# so the state is a tuple (h, c)
decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 5,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
sos_id = model.decoder.sos_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
max_u = 20000 # terminate after this number of steps
u = 0
cache: Dict[
str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
] = {}
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
A = B
B = []
# for hyp in A:
# for h in A:
# if h.ys == hyp.ys[:-1]:
# # update the score of hyp
# decoder_input = torch.tensor(
# [h.ys[-1]], device=device
# ).reshape(1, 1)
# decoder_out, _ = model.decoder(
# decoder_input, h.decoder_state
# )
# logits = model.joiner(current_encoder_out, decoder_out)
# log_prob = logits.log_softmax(dim=-1)
# log_prob = log_prob.squeeze()
# hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
while u < max_u:
y_star = max(A, key=lambda hyp: hyp.log_prob)
A.remove(y_star)
# Note: y_star.ys is unhashable, i.e., cannot be used
# as a key into a dict
cached_key = "_".join(map(str, y_star.ys))
if cached_key not in cache:
decoder_input = torch.tensor(
[y_star.ys[-1]], device=device
).reshape(1, 1)
decoder_out, decoder_state = model.decoder(
decoder_input,
y_star.decoder_state,
)
cache[cached_key] = (decoder_out, decoder_state)
else:
decoder_out, decoder_state = cache[cached_key]
logits = model.joiner(current_encoder_out, decoder_out)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
# If we choose blank here, add the new hypothesis to B.
# Otherwise, add the new hypothesis to A
# First, choose blank
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
# ys[:] returns a copy of ys
new_y_star = Hypothesis(
ys=y_star.ys[:],
log_prob=new_y_star_log_prob,
# Caution: Use y_star.decoder_state here
decoder_state=y_star.decoder_state,
)
B.append(new_y_star)
# Second, choose other labels
for i, v in enumerate(log_prob.tolist()):
if i in (blank_id, sos_id):
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
new_hyp = Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
decoder_state=decoder_state,
)
A.append(new_hyp)
u += 1
# check whether B contains more than "beam" elements more probable
# than the most probable in A
A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
B = sorted(
[hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
key=lambda hyp: hyp.log_prob,
reverse=True,
)
if len(B) >= beam:
B = B[:beam]
break
t += 1
best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
ys = best_hyp.ys[1:] # [1:] to remove the blank
return ys

View File

@ -0,0 +1,922 @@
#!/usr/bin/env python3
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from transformer import Transformer
from icefall.utils import make_pad_mask
class Conformer(Transformer):
"""
Args:
num_features (int): Number of input features
output_dim (int): Number of output dimension
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
dim_feedforward (int): feedforward dimention
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
self,
num_features: int,
output_dim: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
output_dim=output_dim,
subsampling_factor=subsampling_factor,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
cnn_module_kernel,
normalize_before,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
# and throws an error without this change.
self.after_norm = identity
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
if self.normalize_before:
x = self.after_norm(x)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths
class ConformerEncoderLayer(nn.Module):
"""
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
normalize_before: whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = encoder_layer(src, pos_emb)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout)
self.normalize_before = normalize_before
def forward(
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src
class ConformerEncoder(nn.TransformerEncoder):
r"""ConformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = conformer_encoder(src, pos_emb)
"""
def __init__(
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
) -> None:
super(ConformerEncoder, self).__init__(
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
)
def forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
output = src
for mod in self.layers:
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
)
if self.norm is not None:
output = self.norm(output)
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
Examples::
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
return self.multi_head_attention_forward(
query,
key,
value,
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.weight,
self.in_proj.bias,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
)
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is an binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
length, N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
key_padding_mask.size(1), src_len
)
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(
1, 2
) # (batch, head, time1, d_k)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
def identity(x):
return x

View File

@ -0,0 +1,463 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=26,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=12,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=5,
help="Used only when --decoding-method is beam_search",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
else:
return {f"beam_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search")
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,101 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
# TODO(fangjun): Support switching between LSTM and GRU
class Decoder(nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
blank_id: int,
sos_id: int,
num_layers: int,
hidden_dim: int,
output_dim: int,
embedding_dropout: float = 0.0,
rnn_dropout: float = 0.0,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
embedding_dim:
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
sos_id:
The ID of the SOS symbol.
num_layers:
Number of LSTM layers.
hidden_dim:
Hidden dimension of LSTM layers.
output_dim:
Output dimension of the decoder.
embedding_dropout:
Dropout rate for the embedding layer.
rnn_dropout:
Dropout for LSTM layers.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
)
self.embedding_dropout = nn.Dropout(embedding_dropout)
# TODO(fangjun): Use layer normalized LSTM
self.rnn = nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=rnn_dropout,
)
self.blank_id = blank_id
self.sos_id = sos_id
self.output_linear = nn.Linear(hidden_dim, output_dim)
def forward(
self,
y: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
y:
A 2-D tensor of shape (N, U) with BOS prepended.
states:
A tuple of two tensors containing the states information of
LSTM layers in this decoder.
Returns:
Return a tuple containing:
- rnn_output, a tensor of shape (N, U, C)
- (h, c), containing the state information for LSTM layers.
Both are of shape (num_layers, N, C)
"""
embeding_out = self.embedding(y)
embeding_out = self.embedding_dropout(embeding_out)
rnn_out, (h, c) = self.rnn(embeding_out, states)
out = self.output_linear(rnn_out)
return out, (h, c)

View File

@ -0,0 +1,43 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (batch_size, input_seq_len, num_features)
containing the input features.
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
containing unnormalized probabilities, i.e., the output of a
linear layer.
- encoder_out_lens, a tensor of shape (batch_size,) containing
the number of frames in `encoder_out` before padding.
"""
raise NotImplementedError("Please implement it in a subclass")

View File

@ -0,0 +1,250 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer/export.py \
--exp-dir ./transducer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 26 \
--avg 12
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer/decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./transducer/decode.py \
--exp-dir ./transducer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=26,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=12,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,55 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, C).
decoder_out:
Output from the decoder. Its shape is (N, U, C).
Returns:
Return a tensor of shape (N, T, U, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == decoder_out.size(2)
encoder_out = encoder_out.unsqueeze(2)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1)
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
output = self.output_linear(logit)
return output

View File

@ -0,0 +1,127 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
two attributes: `blank_id` and `sos_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface)
assert hasattr(decoder, "blank_id")
assert hasattr(decoder, "sos_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return the transducer loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_id = self.decoder.sos_id
sos_y = add_sos(y, sos_id=sos_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
decoder_out, _ = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
)
return loss

View File

@ -0,0 +1,299 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./transducer/pretrained.py \
--checkpoint ./transducer/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./transducer/exp/epoch-xx.pt`.
Note: ./transducer/exp/pretrained.pt is generated by
./transducer/export.py
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--beam-size",
type=int,
default=5,
help="Used only when --method is beam_search",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -212,6 +212,24 @@ class LayerNormLSTMCell(nn.Module):
if "layernorm" not in name:
nn.init.uniform_(weight, -stdv, stdv)
if "bias_ih" in name or "bias_hh" in name:
# See the paper
# An Empirical Exploration of Recurrent Network Architectures
# https://proceedings.mlr.press/v37/jozefowicz15.pdf
#
# It recommends initializing the bias of the forget gate to
# a large value, such as 1 or 2. In PyTorch, there are two
# biases for the forget gate, we set both of them to 1 here.
#
# See also https://arxiv.org/pdf/1804.04849.pdf
assert weight.ndim == 1
# Layout of the bias:
# | in_gate | forget_gate | cell_gate | output_gate |
start = weight.numel() // 4
end = weight.numel() // 2
with torch.no_grad():
weight[start:end].fill_(1.0)
class LayerNormLSTMLayer(nn.Module):
"""

View File

@ -0,0 +1 @@
../conformer_ctc/subsampling.py

View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer/test_conformer.py
"""
import torch
from conformer import Conformer
def test_conformer():
output_dim = 1024
conformer = Conformer(
num_features=80,
output_dim=output_dim,
subsampling_factor=4,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
use_feat_batchnorm=True,
)
N = 3
T = 100
C = 80
x = torch.randn(N, T, C)
x_lens = torch.tensor([50, 100, 80])
logits, logit_lens = conformer(x, x_lens)
expected_T = ((T - 1) // 2 - 1) // 2
assert logits.shape == (N, expected_T, output_dim)
assert logit_lens.max().item() == expected_T
print(logits.shape)
print(logit_lens)
def main():
test_conformer()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer/test_decoder.py
"""
import torch
from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
sos_id = 2
embedding_dim = 128
num_layers = 2
hidden_dim = 6
output_dim = 8
N = 3
U = 5
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
sos_id=sos_id,
num_layers=num_layers,
hidden_dim=hidden_dim,
output_dim=output_dim,
embedding_dropout=0.0,
rnn_dropout=0.0,
)
x = torch.randint(1, vocab_size, (N, U))
decoder_out, (h, c) = decoder(x)
assert decoder_out.shape == (N, U, output_dim)
assert h.shape == (num_layers, N, hidden_dim)
assert c.shape == (num_layers, N, hidden_dim)
decoder_out, (h, c) = decoder(x, (h, c))
assert decoder_out.shape == (N, U, output_dim)
assert h.shape == (num_layers, N, hidden_dim)
assert c.shape == (num_layers, N, hidden_dim)
def main():
test_decoder()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer/test_joiner.py
"""
import torch
from joiner import Joiner
def test_joiner():
N = 2
T = 3
C = 4
U = 5
joiner = Joiner(C, 10)
encoder_out = torch.rand(N, T, C)
decoder_out = torch.rand(N, U, C)
joint = joiner(encoder_out, decoder_out)
assert joint.shape == (N, T, U, 10)
def main():
test_joiner()
if __name__ == "__main__":
main()

View File

@ -15,9 +15,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer/test_rnn.py
"""
import torch
import torch.nn as nn
from transducer.rnn import (
from rnn import (
LayerNormGRU,
LayerNormGRUCell,
LayerNormGRULayer,
@ -499,6 +505,28 @@ def test_layernorm_lstm_with_projection_forward(device="cpu"):
assert_allclose(x.grad, x_clone.grad)
def test_lstm_forget_gate_bias(device="cpu"):
input_size = 2
hidden_size = 3
num_layers = 4
bias = True
lstm = LayerNormLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
ln=nn.Identity,
device=device,
)
for name, weight in lstm.named_parameters():
if "bias_hh" in name or "bias_ih" in name:
start = weight.numel() // 4
end = weight.numel() // 2
expected = torch.ones(hidden_size).to(weight)
assert torch.all(torch.eq(weight[start:end], expected))
def test_layernorm_gru_cell_jit(device="cpu"):
input_size = 10
hidden_size = 20
@ -735,6 +763,8 @@ def _test_lstm(device):
test_layernorm_lstm_with_projection_jit(device)
test_layernorm_lstm_forward(device)
test_layernorm_lstm_with_projection_forward(device)
#
test_lstm_forget_gate_bias(device)
def _test_gru(device):

View File

@ -0,0 +1,89 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer/test_transducer.py
"""
import k2
import torch
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
def test_transducer():
# encoder params
input_dim = 10
output_dim = 20
# decoder params
vocab_size = 3
blank_id = 0
sos_id = 2
embedding_dim = 128
num_layers = 2
encoder = Conformer(
num_features=input_dim,
output_dim=output_dim,
subsampling_factor=4,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
use_feat_batchnorm=True,
)
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
sos_id=sos_id,
num_layers=num_layers,
hidden_dim=output_dim,
output_dim=output_dim,
embedding_dropout=0.0,
rnn_dropout=0.0,
)
joiner = Joiner(output_dim, vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
N = y.dim0
T = 50
x = torch.rand(N, T, input_dim)
x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
x_lens[0] = T
loss = transducer(x, x_lens, y)
print(loss)
def main():
test_transducer()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer/test_transformer.py
"""
import torch
from transformer import Transformer
def test_transformer():
output_dim = 1024
transformer = Transformer(
num_features=80,
output_dim=output_dim,
subsampling_factor=4,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
use_feat_batchnorm=True,
)
N = 3
T = 100
C = 80
x = torch.randn(N, T, C)
x_lens = torch.tensor([50, 100, 80])
logits, logit_lens = transformer(x, x_lens)
expected_T = ((T - 1) // 2 - 1) // 2
assert logits.shape == (N, expected_T, output_dim)
assert logit_lens.max().item() == expected_T
print(logits.shape)
print(logit_lens)
def main():
test_transformer()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,743 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer/exp \
--full-libri 1 \
--max-duration 250 \
--lr-factor 2.5
"""
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lr-factor",
type=float,
default=3.0,
help="The lr_factor for Noam optimizer",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 800
params.warm_step = 8000
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,429 @@
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import make_pad_mask
class Transformer(EncoderInterface):
def __init__(
self,
num_features: int,
output_dim: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
num_features:
The input dimension of the model.
output_dim:
The output dimension of the model.
subsampling_factor:
Number of output frames is num_in_frames // subsampling_factor.
Currently, subsampling_factor MUST be 4.
d_model:
Attention dimension.
nhead:
Number of heads in multi-head attention.
Must satisfy d_model // nhead == 0.
dim_feedforward:
The output dimension of the feedforward layers in encoder.
num_encoder_layers:
Number of encoder layers.
dropout:
Dropout in encoder.
normalize_before:
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.output_dim = output_dim
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = PositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
encoder_norm = nn.LayerNorm(d_model)
else:
encoder_norm = None
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_encoder_layers,
norm=encoder_norm,
)
# TODO(fangjun): remove dropout
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths
class TransformerEncoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerEncoderLayer.
Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model:
the number of expected features in the input (required).
nhead:
the number of heads in the multiheadattention models (required).
dim_feedforward:
the dimension of the feedforward network model (default=2048).
dropout:
the dropout value (default=0.1).
activation:
the activation function of intermediate layer, relu or
gelu (default=relu).
normalize_before:
whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerEncoderLayer, self).__setstate__(state)
def forward(
self,
src: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional)
Shape:
src: (S, N, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length,
N is the batch size, E is the feature number
"""
residual = src
if self.normalize_before:
src = self.norm1(src)
src2 = self.self_attn(
src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout1(src2)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src2)
if not self.normalize_before:
src = self.norm2(src)
return src
def _get_activation_fn(activation: str):
if activation == "relu":
return nn.functional.relu
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module):
"""This class implements the positional encoding
proposed in the following paper:
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
PE(pos, 2i) = sin(pos / (10000^(2i/d_modle))
PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle))
Note::
1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
= exp(-1* 2i / d_model * log(100000))
= exp(2i * -(log(10000) / d_model))
"""
def __init__(self, d_model: int, dropout: float = 0.1) -> None:
"""
Args:
d_model:
Embedding dimension.
dropout:
Dropout probability to be applied to the output of this module.
"""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
# not doing: self.pe = None because of errors thrown by torchscript
self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
The shape of `self.pe` is (1, T1, d_model). The shape of the input x
is (N, T, d_model). If T > T1, then we change the shape of self.pe
to (N, T, d_model). Otherwise, nothing is done.
Args:
x:
It is a tensor of shape (N, T, C).
Returns:
Return None.
"""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
# Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Add positional encoding.
Args:
x:
Its shape is (N, T, C)
Returns:
Return a tensor of shape (N, T, C)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]
return self.dropout(x)
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* min(step ** (-0.5), step * self.warmup ** (-1.5))
)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)

View File

@ -0,0 +1,19 @@
## Introduction
The encoder consists of LSTM layers in this folder. You can use the
following command to start the training:
```bash
cd egs/librispeech/ASR
export CUDA_VISIBLE_DEVICES="0,1,2"
./transducer_lstm/train.py \
--world-size 3 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_lstm/exp \
--full-libri 1 \
--max-duration 300 \
--lr-factor 3
```

View File

@ -0,0 +1 @@
../transducer/asr_datamodule.py

View File

@ -0,0 +1,212 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
hyp = []
max_u = 1000 # terminate after this number of steps
u = 0
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out)
# logits is (1, 1, 1, vocab_size)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
# TODO: Use logits.argmax()
y = log_prob.argmax()
if y != blank_id:
hyp.append(y.item())
y = y.reshape(1, 1)
decoder_out, (h, c) = model.decoder(y, (h, c))
u += 1
else:
t += 1
return hyp
@dataclass
class Hypothesis:
ys: List[int] # the predicted sequences so far
log_prob: float # The log prob of ys
# Optional decoder state. We assume it is LSTM for now,
# so the state is a tuple (h, c)
decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 5,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
sos_id = model.decoder.sos_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
max_u = 20000 # terminate after this number of steps
u = 0
cache: Dict[
str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
] = {}
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
A = B
B = []
# for hyp in A:
# for h in A:
# if h.ys == hyp.ys[:-1]:
# # update the score of hyp
# decoder_input = torch.tensor(
# [h.ys[-1]], device=device
# ).reshape(1, 1)
# decoder_out, _ = model.decoder(
# decoder_input, h.decoder_state
# )
# logits = model.joiner(current_encoder_out, decoder_out)
# log_prob = logits.log_softmax(dim=-1)
# log_prob = log_prob.squeeze()
# hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
while u < max_u:
y_star = max(A, key=lambda hyp: hyp.log_prob)
A.remove(y_star)
# Note: y_star.ys is unhashable, i.e., cannot be used
# as a key into a dict
cached_key = "_".join(map(str, y_star.ys))
if cached_key not in cache:
decoder_input = torch.tensor(
[y_star.ys[-1]], device=device
).reshape(1, 1)
decoder_out, decoder_state = model.decoder(
decoder_input,
y_star.decoder_state,
)
cache[cached_key] = (decoder_out, decoder_state)
else:
decoder_out, decoder_state = cache[cached_key]
logits = model.joiner(current_encoder_out, decoder_out)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
# If we choose blank here, add the new hypothesis to B.
# Otherwise, add the new hypothesis to A
# First, choose blank
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
# ys[:] returns a copy of ys
new_y_star = Hypothesis(
ys=y_star.ys[:],
log_prob=new_y_star_log_prob,
# Caution: Use y_star.decoder_state here
decoder_state=y_star.decoder_state,
)
B.append(new_y_star)
# Second, choose other labels
for i, v in enumerate(log_prob.tolist()):
if i in (blank_id, sos_id):
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
new_hyp = Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
decoder_state=decoder_state,
)
A.append(new_hyp)
u += 1
# check whether B contains more than "beam" elements more probable
# than the most probable in A
A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
B = sorted(
[hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
key=lambda hyp: hyp.log_prob,
reverse=True,
)
if len(B) >= beam:
B = B[:beam]
break
t += 1
best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
ys = best_hyp.ys[1:] # [1:] to remove the blank
return ys

View File

@ -0,0 +1,457 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_lstm/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_lstm/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_lstm/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_lstm/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search
from decoder import Decoder
from encoder import LstmEncoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=77,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=55,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_lstm/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=5,
help="Used only when --decoding-method is beam_search",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"encoder_hidden_size": 1024,
"num_encoder_layers": 4,
"proj_size": 512,
"vgg_frontend": False,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
encoder = LstmEncoder(
num_features=params.feature_dim,
hidden_size=params.encoder_hidden_size,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
else:
return {f"beam_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in ("greedy_search", "beam_search")
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
model.device = device
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,101 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
# TODO(fangjun): Support switching between LSTM and GRU
class Decoder(nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
blank_id: int,
sos_id: int,
num_layers: int,
hidden_dim: int,
output_dim: int,
embedding_dropout: float = 0.0,
rnn_dropout: float = 0.0,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
embedding_dim:
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
sos_id:
The ID of the SOS symbol.
num_layers:
Number of LSTM layers.
hidden_dim:
Hidden dimension of LSTM layers.
output_dim:
Output dimension of the decoder.
embedding_dropout:
Dropout rate for the embedding layer.
rnn_dropout:
Dropout for LSTM layers.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
)
self.embedding_dropout = nn.Dropout(embedding_dropout)
# TODO(fangjun): Use layer normalized LSTM
self.rnn = nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=rnn_dropout,
)
self.blank_id = blank_id
self.sos_id = sos_id
self.output_linear = nn.Linear(hidden_dim, output_dim)
def forward(
self,
y: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
y:
A 2-D tensor of shape (N, U) with BOS prepended.
states:
A tuple of two tensors containing the states information of
LSTM layers in this decoder.
Returns:
Return a tuple containing:
- rnn_output, a tensor of shape (N, U, C)
- (h, c), containing the state information for LSTM layers.
Both are of shape (num_layers, N, C)
"""
embeding_out = self.embedding(y)
embeding_out = self.embedding_dropout(embeding_out)
rnn_out, (h, c) = self.rnn(embeding_out, states)
out = self.output_linear(rnn_out)
return out, (h, c)

View File

@ -0,0 +1,115 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
class LstmEncoder(EncoderInterface):
def __init__(
self,
num_features: int,
hidden_size: int,
output_dim: int,
subsampling_factor: int = 4,
num_encoder_layers: int = 12,
dropout: float = 0.1,
vgg_frontend: bool = False,
proj_size: int = 0,
):
super().__init__()
real_hidden_size = proj_size if proj_size > 0 else hidden_size
assert (
subsampling_factor == 4
), "Only subsampling_factor==4 is supported at present"
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
else:
self.encoder_embed = Conv2dSubsampling(
num_features, real_hidden_size
)
self.rnn = nn.LSTM(
input_size=hidden_size,
hidden_size=hidden_size,
num_layers=num_encoder_layers,
bias=True,
proj_size=proj_size,
batch_first=True,
dropout=dropout,
bidirectional=False,
)
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout),
nn.Linear(real_hidden_size, output_dim),
)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
x = self.encoder_embed(x)
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(1) == lengths.max().item(), (
x.size(1),
lengths.max(),
)
if False:
# It is commented out as DPP complains that not all parameters are
# used. Need more checks later for the reason.
#
# Caution: We assume the dataloader returns utterances with
# duration being sorted in decreasing order
packed_x = pack_padded_sequence(
input=x,
lengths=lengths.cpu(),
batch_first=True,
enforce_sorted=True,
)
packed_rnn_out, _ = self.rnn(packed_x)
rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True)
else:
rnn_out, _ = self.rnn(x)
logits = self.encoder_output_layer(rnn_out)
return logits, lengths

View File

@ -0,0 +1,43 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (batch_size, input_seq_len, num_features)
containing the input features.
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
containing unnormalized probabilities, i.e., the output of a
linear layer.
- encoder_out_lens, a tensor of shape (batch_size,) containing
the number of frames in `encoder_out` before padding.
"""
raise NotImplementedError("Please implement it in a subclass")

View File

@ -0,0 +1,55 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, C).
decoder_out:
Output from the decoder. Its shape is (N, U, C).
Returns:
Return a tensor of shape (N, T, U, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == decoder_out.size(2)
encoder_out = encoder_out.unsqueeze(2)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1)
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
output = self.output_linear(logit)
return output

View File

@ -0,0 +1,127 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
two attributes: `blank_id` and `sos_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface)
assert hasattr(decoder, "blank_id")
assert hasattr(decoder, "sos_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return the transducer loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_id = self.decoder.sos_id
sos_y = add_sos(y, sos_id=sos_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
decoder_out, _ = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
)
return loss

View File

@ -0,0 +1,104 @@
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* min(step ** (-0.5), step * self.warmup ** (-1.5))
)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)

View File

@ -0,0 +1 @@
../transducer/subsampling.py

View File

@ -0,0 +1,48 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_lstm/test_encoder.py
"""
from encoder import LstmEncoder
def test_encoder():
encoder = LstmEncoder(
num_features=80,
hidden_size=1024,
proj_size=512,
output_dim=512,
subsampling_factor=4,
num_encoder_layers=12,
)
num_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
print(num_params)
# 93979284
# 66427392
def main():
test_encoder()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,738 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2"
./transducer_lstm/train.py \
--world-size 3 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_lstm/exp \
--full-libri 1 \
--max-duration 400 \
--lr-factor 3
"""
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from encoder import LstmEncoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from noam import Noam
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_lstm/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_lstm/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lr-factor",
type=float,
default=3.0,
help="The lr_factor for Noam optimizer",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"encoder_hidden_size": 1024,
"num_encoder_layers": 4,
"proj_size": 512,
"vgg_frontend": False,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
encoder = LstmEncoder(
num_features=params.feature_dim,
hidden_size=params.encoder_hidden_size,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 800
params.warm_step = 8000
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
checkpoints = load_checkpoint_if_available(params=params, model=model)
num_param = sum([p.numel() for p in model.parameters() if p.requires_grad])
logging.info(f"Number of model parameters: {num_param}")
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
model.parameters(),
model_size=params.encoder_hidden_size,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,22 @@
## Introduction
The decoder, i.e., the prediction network, is from
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
(Rnn-Transducer with Stateless Prediction Network)
You can use the following command to start the training:
```bash
cd egs/librispeech/ASR
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp \
--full-libri 1 \
--max-duration 250 \
--lr-factor 2.5
```

View File

@ -0,0 +1 @@
../transducer/asr_datamodule.py

View File

@ -0,0 +1,212 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
hyp = []
max_u = 1000 # terminate after this number of steps
u = 0
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out)
# logits is (1, 1, 1, vocab_size)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
# TODO: Use logits.argmax()
y = log_prob.argmax()
if y != blank_id:
hyp.append(y.item())
y = y.reshape(1, 1)
decoder_out, (h, c) = model.decoder(y, (h, c))
u += 1
else:
t += 1
return hyp
@dataclass
class Hypothesis:
ys: List[int] # the predicted sequences so far
log_prob: float # The log prob of ys
# Optional decoder state. We assume it is LSTM for now,
# so the state is a tuple (h, c)
decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 5,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
sos_id = model.decoder.sos_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
max_u = 20000 # terminate after this number of steps
u = 0
cache: Dict[
str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
] = {}
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
A = B
B = []
# for hyp in A:
# for h in A:
# if h.ys == hyp.ys[:-1]:
# # update the score of hyp
# decoder_input = torch.tensor(
# [h.ys[-1]], device=device
# ).reshape(1, 1)
# decoder_out, _ = model.decoder(
# decoder_input, h.decoder_state
# )
# logits = model.joiner(current_encoder_out, decoder_out)
# log_prob = logits.log_softmax(dim=-1)
# log_prob = log_prob.squeeze()
# hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
while u < max_u:
y_star = max(A, key=lambda hyp: hyp.log_prob)
A.remove(y_star)
# Note: y_star.ys is unhashable, i.e., cannot be used
# as a key into a dict
cached_key = "_".join(map(str, y_star.ys))
if cached_key not in cache:
decoder_input = torch.tensor(
[y_star.ys[-1]], device=device
).reshape(1, 1)
decoder_out, decoder_state = model.decoder(
decoder_input,
y_star.decoder_state,
)
cache[cached_key] = (decoder_out, decoder_state)
else:
decoder_out, decoder_state = cache[cached_key]
logits = model.joiner(current_encoder_out, decoder_out)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
# If we choose blank here, add the new hypothesis to B.
# Otherwise, add the new hypothesis to A
# First, choose blank
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
# ys[:] returns a copy of ys
new_y_star = Hypothesis(
ys=y_star.ys[:],
log_prob=new_y_star_log_prob,
# Caution: Use y_star.decoder_state here
decoder_state=y_star.decoder_state,
)
B.append(new_y_star)
# Second, choose other labels
for i, v in enumerate(log_prob.tolist()):
if i in (blank_id, sos_id):
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
new_hyp = Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
decoder_state=decoder_state,
)
A.append(new_hyp)
u += 1
# check whether B contains more than "beam" elements more probable
# than the most probable in A
A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
B = sorted(
[hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
key=lambda hyp: hyp.log_prob,
reverse=True,
)
if len(B) >= beam:
B = B[:beam]
break
t += 1
best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
ys = best_hyp.ys[1:] # [1:] to remove the blank
return ys

View File

@ -0,0 +1,922 @@
#!/usr/bin/env python3
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from transducer.transformer import Transformer
from icefall.utils import make_pad_mask
class Conformer(Transformer):
"""
Args:
num_features (int): Number of input features
output_dim (int): Number of output dimension
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
dim_feedforward (int): feedforward dimention
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
self,
num_features: int,
output_dim: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
output_dim=output_dim,
subsampling_factor=subsampling_factor,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
cnn_module_kernel,
normalize_before,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
# and throws an error without this change.
self.after_norm = identity
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
if self.normalize_before:
x = self.after_norm(x)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths
class ConformerEncoderLayer(nn.Module):
"""
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
normalize_before: whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = encoder_layer(src, pos_emb)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0
)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout)
self.normalize_before = normalize_before
def forward(
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src
class ConformerEncoder(nn.TransformerEncoder):
r"""ConformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = conformer_encoder(src, pos_emb)
"""
def __init__(
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
) -> None:
super(ConformerEncoder, self).__init__(
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
)
def forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
output = src
for mod in self.layers:
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
)
if self.norm is not None:
output = self.norm(output)
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
Examples::
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
return self.multi_head_attention_forward(
query,
key,
value,
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.weight,
self.in_proj.bias,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
)
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is an binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
length, N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
key_padding_mask.size(1), src_len
)
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(
1, 2
) # (batch, head, time1, d_k)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
def identity(x):
return x

View File

@ -0,0 +1,65 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class Decoder(nn.Module):
"""This class implements the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
embedding_dim: int,
blank_id: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
embedding_dim:
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
)
self.blank_id = blank_id
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U) with blank prepended.
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
embeding_out = self.embedding(y)
return embeding_out

View File

@ -0,0 +1,43 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A tensor of shape (batch_size, input_seq_len, num_features)
containing the input features.
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
containing unnormalized probabilities, i.e., the output of a
linear layer.
- encoder_out_lens, a tensor of shape (batch_size,) containing
the number of frames in `encoder_out` before padding.
"""
raise NotImplementedError("Please implement it in a subclass")

View File

@ -0,0 +1,55 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, C).
decoder_out:
Output from the decoder. Its shape is (N, U, C).
Returns:
Return a tensor of shape (N, T, U, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == decoder_out.size(2)
encoder_out = encoder_out.unsqueeze(2)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1)
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
output = self.output_linear(logit)
return output

View File

@ -0,0 +1,125 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return the transducer loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
decoder_out = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
)
return loss

View File

@ -0,0 +1 @@
../transducer/subsampling.py

View File

@ -0,0 +1,734 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp \
--full-libri 1 \
--max-duration 250 \
--lr-factor 2.5
"""
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=78,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_stateless/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lr-factor",
type=float,
default=5.0,
help="The lr_factor for Noam optimizer",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 800
params.warm_step = 8000
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,429 @@
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transducer.encoder_interface import EncoderInterface
from transducer.subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import make_pad_mask
class Transformer(EncoderInterface):
def __init__(
self,
num_features: int,
output_dim: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
num_features:
The input dimension of the model.
output_dim:
The output dimension of the model.
subsampling_factor:
Number of output frames is num_in_frames // subsampling_factor.
Currently, subsampling_factor MUST be 4.
d_model:
Attention dimension.
nhead:
Number of heads in multi-head attention.
Must satisfy d_model // nhead == 0.
dim_feedforward:
The output dimension of the feedforward layers in encoder.
num_encoder_layers:
Number of encoder layers.
dropout:
Dropout in encoder.
normalize_before:
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.output_dim = output_dim
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = PositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
encoder_norm = nn.LayerNorm(d_model)
else:
encoder_norm = None
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_encoder_layers,
norm=encoder_norm,
)
# TODO(fangjun): remove dropout
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths
class TransformerEncoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerEncoderLayer.
Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model:
the number of expected features in the input (required).
nhead:
the number of heads in the multiheadattention models (required).
dim_feedforward:
the dimension of the feedforward network model (default=2048).
dropout:
the dropout value (default=0.1).
activation:
the activation function of intermediate layer, relu or
gelu (default=relu).
normalize_before:
whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerEncoderLayer, self).__setstate__(state)
def forward(
self,
src: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional)
Shape:
src: (S, N, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length,
N is the batch size, E is the feature number
"""
residual = src
if self.normalize_before:
src = self.norm1(src)
src2 = self.self_attn(
src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout1(src2)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src2)
if not self.normalize_before:
src = self.norm2(src)
return src
def _get_activation_fn(activation: str):
if activation == "relu":
return nn.functional.relu
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module):
"""This class implements the positional encoding
proposed in the following paper:
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
PE(pos, 2i) = sin(pos / (10000^(2i/d_modle))
PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle))
Note::
1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
= exp(-1* 2i / d_model * log(100000))
= exp(2i * -(log(10000) / d_model))
"""
def __init__(self, d_model: int, dropout: float = 0.1) -> None:
"""
Args:
d_model:
Embedding dimension.
dropout:
Dropout probability to be applied to the output of this module.
"""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
# not doing: self.pe = None because of errors thrown by torchscript
self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
The shape of `self.pe` is (1, T1, d_model). The shape of the input x
is (N, T, d_model). If T > T1, then we change the shape of self.pe
to (N, T, d_model). Otherwise, nothing is done.
Args:
x:
It is a tensor of shape (N, T, C).
Returns:
Return None.
"""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
# Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Add positional encoding.
Args:
x:
Its shape is (N, T, C)
Returns:
Return a tensor of shape (N, T, C)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]
return self.dropout(x)
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* min(step ** (-0.5), step * self.warmup ** (-1.5))
)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)

View File

@ -20,7 +20,7 @@ import torch
from transducer.model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[str]:
"""
Args:
model:
@ -42,7 +42,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
T = encoder_out.size(1)
t = 0
hyp = []
max_u = 1000 # terminte after this number of steps
max_u = 1000 # terminate after this number of steps
u = 0
while t < T and u < max_u: