mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Use conformer/transformer model as encoder.
This commit is contained in:
parent
f802758fca
commit
f5199d37c4
@ -22,20 +22,21 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
from transducer.transformer import Transformer
|
||||||
|
|
||||||
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
class Conformer(Transformer):
|
class Conformer(Transformer):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
num_features (int): Number of input features
|
num_features (int): Number of input features
|
||||||
num_classes (int): Number of output classes
|
output_dim (int): Number of output dimension
|
||||||
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
||||||
d_model (int): attention dimension
|
d_model (int): attention dimension
|
||||||
nhead (int): number of head
|
nhead (int): number of head
|
||||||
dim_feedforward (int): feedforward dimention
|
dim_feedforward (int): feedforward dimention
|
||||||
num_encoder_layers (int): number of encoder layers
|
num_encoder_layers (int): number of encoder layers
|
||||||
num_decoder_layers (int): number of decoder layers
|
|
||||||
dropout (float): dropout rate
|
dropout (float): dropout rate
|
||||||
cnn_module_kernel (int): Kernel size of convolution module
|
cnn_module_kernel (int): Kernel size of convolution module
|
||||||
normalize_before (bool): whether to use layer_norm before the first block.
|
normalize_before (bool): whether to use layer_norm before the first block.
|
||||||
@ -45,13 +46,12 @@ class Conformer(Transformer):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_features: int,
|
num_features: int,
|
||||||
num_classes: int,
|
output_dim: int,
|
||||||
subsampling_factor: int = 4,
|
subsampling_factor: int = 4,
|
||||||
d_model: int = 256,
|
d_model: int = 256,
|
||||||
nhead: int = 4,
|
nhead: int = 4,
|
||||||
dim_feedforward: int = 2048,
|
dim_feedforward: int = 2048,
|
||||||
num_encoder_layers: int = 12,
|
num_encoder_layers: int = 12,
|
||||||
num_decoder_layers: int = 6,
|
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
@ -60,13 +60,12 @@ class Conformer(Transformer):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super(Conformer, self).__init__(
|
super(Conformer, self).__init__(
|
||||||
num_features=num_features,
|
num_features=num_features,
|
||||||
num_classes=num_classes,
|
output_dim=output_dim,
|
||||||
subsampling_factor=subsampling_factor,
|
subsampling_factor=subsampling_factor,
|
||||||
d_model=d_model,
|
d_model=d_model,
|
||||||
nhead=nhead,
|
nhead=nhead,
|
||||||
dim_feedforward=dim_feedforward,
|
dim_feedforward=dim_feedforward,
|
||||||
num_encoder_layers=num_encoder_layers,
|
num_encoder_layers=num_encoder_layers,
|
||||||
num_decoder_layers=num_decoder_layers,
|
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
normalize_before=normalize_before,
|
normalize_before=normalize_before,
|
||||||
vgg_frontend=vgg_frontend,
|
vgg_frontend=vgg_frontend,
|
||||||
@ -92,38 +91,45 @@ class Conformer(Transformer):
|
|||||||
# and throws an error without this change.
|
# and throws an error without this change.
|
||||||
self.after_norm = identity
|
self.after_norm = identity
|
||||||
|
|
||||||
def run_encoder(
|
def forward(
|
||||||
self, x: Tensor, supervisions: Optional[Supervisions] = None
|
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
The model input. Its shape is (N, T, C).
|
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||||
supervisions:
|
x_lens:
|
||||||
Supervision in lhotse format.
|
A tensor of shape (batch_size,) containing the number of frames in
|
||||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
`x` before padding.
|
||||||
CAUTION: It contains length information, i.e., start and number of
|
|
||||||
frames, before subsampling
|
|
||||||
It is read directly from the batch, without any sorting. It is used
|
|
||||||
to compute encoder padding mask, which is used as memory key padding
|
|
||||||
mask for the decoder.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
Return a tuple containing 2 tensors:
|
||||||
Tensor: Mask tensor of dimension (batch_size, input_length)
|
- logits, its shape is (batch_size, output_seq_len, output_dim)
|
||||||
|
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||||
|
of frames in `logits` before padding.
|
||||||
"""
|
"""
|
||||||
|
if self.use_feat_batchnorm:
|
||||||
|
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||||
|
x = self.feat_batchnorm(x)
|
||||||
|
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||||
|
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
x, pos_emb = self.encoder_pos(x)
|
x, pos_emb = self.encoder_pos(x)
|
||||||
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
mask = encoder_padding_mask(x.size(0), supervisions)
|
|
||||||
if mask is not None:
|
# Caution: We assume the subsampling factor is 4!
|
||||||
mask = mask.to(x.device)
|
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
|
assert x.size(0) == lengths.max().item()
|
||||||
|
mask = make_pad_mask(lengths)
|
||||||
|
|
||||||
|
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
|
||||||
|
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
x = self.after_norm(x)
|
x = self.after_norm(x)
|
||||||
|
|
||||||
return x, mask
|
logits = self.encoder_output_layer(x)
|
||||||
|
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
|
return logits, lengths
|
||||||
|
|
||||||
|
|
||||||
class ConformerEncoderLayer(nn.Module):
|
class ConformerEncoderLayer(nn.Module):
|
||||||
|
@ -21,16 +21,6 @@ import torch.nn as nn
|
|||||||
|
|
||||||
|
|
||||||
class EncoderInterface(nn.Module):
|
class EncoderInterface(nn.Module):
|
||||||
def __init__(self, num_features: int, output_dim: int):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
num_features:
|
|
||||||
The dimension of the input features.
|
|
||||||
output_dim:
|
|
||||||
Output dimension of the model.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
60
egs/librispeech/ASR/transducer/test_conformer.py
Executable file
60
egs/librispeech/ASR/transducer/test_conformer.py
Executable file
@ -0,0 +1,60 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./transducer/test_conformer.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transducer.conformer import Conformer
|
||||||
|
|
||||||
|
|
||||||
|
def test_conformer():
|
||||||
|
output_dim = 1024
|
||||||
|
conformer = Conformer(
|
||||||
|
num_features=80,
|
||||||
|
output_dim=output_dim,
|
||||||
|
subsampling_factor=4,
|
||||||
|
d_model=512,
|
||||||
|
nhead=8,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
num_encoder_layers=12,
|
||||||
|
use_feat_batchnorm=True,
|
||||||
|
)
|
||||||
|
N = 3
|
||||||
|
T = 100
|
||||||
|
C = 80
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x_lens = torch.tensor([50, 100, 80])
|
||||||
|
logits, logit_lens = conformer(x, x_lens)
|
||||||
|
|
||||||
|
expected_T = ((T - 1) // 2 - 1) // 2
|
||||||
|
assert logits.shape == (N, expected_T, output_dim)
|
||||||
|
assert logit_lens.max().item() == expected_T
|
||||||
|
print(logits.shape)
|
||||||
|
print(logit_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_conformer()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -18,7 +18,7 @@
|
|||||||
"""
|
"""
|
||||||
To run this file, do:
|
To run this file, do:
|
||||||
|
|
||||||
cd icefall/egs/yesno/ASR
|
cd icefall/egs/librispeech/ASR
|
||||||
python ./transducer/test_decoder.py
|
python ./transducer/test_decoder.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
"""
|
"""
|
||||||
To run this file, do:
|
To run this file, do:
|
||||||
|
|
||||||
cd icefall/egs/yesno/ASR
|
cd icefall/egs/librispeech/ASR
|
||||||
python ./transducer/test_joiner.py
|
python ./transducer/test_joiner.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -15,6 +15,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./transducer/test_rnn.py
|
||||||
|
"""
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transducer.rnn import (
|
from transducer.rnn import (
|
||||||
|
88
egs/librispeech/ASR/transducer/test_transducer.py
Executable file
88
egs/librispeech/ASR/transducer/test_transducer.py
Executable file
@ -0,0 +1,88 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./transducer/test_transducer.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
from transducer.conformer import Conformer
|
||||||
|
from transducer.decoder import Decoder
|
||||||
|
from transducer.joiner import Joiner
|
||||||
|
from transducer.model import Transducer
|
||||||
|
|
||||||
|
|
||||||
|
def test_transducer():
|
||||||
|
# encoder params
|
||||||
|
input_dim = 10
|
||||||
|
output_dim = 20
|
||||||
|
|
||||||
|
# decoder params
|
||||||
|
vocab_size = 3
|
||||||
|
blank_id = 0
|
||||||
|
sos_id = 2
|
||||||
|
embedding_dim = 128
|
||||||
|
num_layers = 2
|
||||||
|
|
||||||
|
encoder = Conformer(
|
||||||
|
num_features=input_dim,
|
||||||
|
output_dim=output_dim,
|
||||||
|
subsampling_factor=4,
|
||||||
|
d_model=512,
|
||||||
|
nhead=8,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
num_encoder_layers=12,
|
||||||
|
use_feat_batchnorm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = Decoder(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
blank_id=blank_id,
|
||||||
|
sos_id=sos_id,
|
||||||
|
num_layers=num_layers,
|
||||||
|
hidden_dim=output_dim,
|
||||||
|
embedding_dropout=0.0,
|
||||||
|
rnn_dropout=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = Joiner(output_dim, vocab_size)
|
||||||
|
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
|
||||||
|
|
||||||
|
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
|
||||||
|
N = y.dim0
|
||||||
|
T = 50
|
||||||
|
|
||||||
|
x = torch.rand(N, T, input_dim)
|
||||||
|
x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
|
||||||
|
x_lens[0] = T
|
||||||
|
|
||||||
|
loss = transducer(x, x_lens, y)
|
||||||
|
print(loss)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_transducer()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
60
egs/librispeech/ASR/transducer/test_transformer.py
Executable file
60
egs/librispeech/ASR/transducer/test_transformer.py
Executable file
@ -0,0 +1,60 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./transducer/test_transformer.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transducer.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
def test_transformer():
|
||||||
|
output_dim = 1024
|
||||||
|
transformer = Transformer(
|
||||||
|
num_features=80,
|
||||||
|
output_dim=output_dim,
|
||||||
|
subsampling_factor=4,
|
||||||
|
d_model=512,
|
||||||
|
nhead=8,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
num_encoder_layers=12,
|
||||||
|
use_feat_batchnorm=True,
|
||||||
|
)
|
||||||
|
N = 3
|
||||||
|
T = 100
|
||||||
|
C = 80
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x_lens = torch.tensor([50, 100, 80])
|
||||||
|
logits, logit_lens = transformer(x, x_lens)
|
||||||
|
|
||||||
|
expected_T = ((T - 1) // 2 - 1) // 2
|
||||||
|
assert logits.shape == (N, expected_T, output_dim)
|
||||||
|
assert logit_lens.max().item() == expected_T
|
||||||
|
print(logits.shape)
|
||||||
|
print(logit_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_transformer()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -16,29 +16,26 @@
|
|||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from label_smoothing import LabelSmoothingLoss
|
|
||||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from transducer.encoder_interface import EncoderInterface
|
||||||
|
|
||||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
from icefall.utils import make_pad_mask
|
||||||
Supervisions = Dict[str, torch.Tensor]
|
|
||||||
|
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
class Transformer(EncoderInterface):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_features: int,
|
num_features: int,
|
||||||
num_classes: int,
|
output_dim: int,
|
||||||
subsampling_factor: int = 4,
|
subsampling_factor: int = 4,
|
||||||
d_model: int = 256,
|
d_model: int = 256,
|
||||||
nhead: int = 4,
|
nhead: int = 4,
|
||||||
dim_feedforward: int = 2048,
|
dim_feedforward: int = 2048,
|
||||||
num_encoder_layers: int = 12,
|
num_encoder_layers: int = 12,
|
||||||
num_decoder_layers: int = 6,
|
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
vgg_frontend: bool = False,
|
vgg_frontend: bool = False,
|
||||||
@ -48,7 +45,7 @@ class Transformer(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
num_features:
|
num_features:
|
||||||
The input dimension of the model.
|
The input dimension of the model.
|
||||||
num_classes:
|
output_dim:
|
||||||
The output dimension of the model.
|
The output dimension of the model.
|
||||||
subsampling_factor:
|
subsampling_factor:
|
||||||
Number of output frames is num_in_frames // subsampling_factor.
|
Number of output frames is num_in_frames // subsampling_factor.
|
||||||
@ -59,13 +56,11 @@ class Transformer(nn.Module):
|
|||||||
Number of heads in multi-head attention.
|
Number of heads in multi-head attention.
|
||||||
Must satisfy d_model // nhead == 0.
|
Must satisfy d_model // nhead == 0.
|
||||||
dim_feedforward:
|
dim_feedforward:
|
||||||
The output dimension of the feedforward layers in encoder/decoder.
|
The output dimension of the feedforward layers in encoder.
|
||||||
num_encoder_layers:
|
num_encoder_layers:
|
||||||
Number of encoder layers.
|
Number of encoder layers.
|
||||||
num_decoder_layers:
|
|
||||||
Number of decoder layers.
|
|
||||||
dropout:
|
dropout:
|
||||||
Dropout in encoder/decoder.
|
Dropout in encoder.
|
||||||
normalize_before:
|
normalize_before:
|
||||||
If True, use pre-layer norm; False to use post-layer norm.
|
If True, use pre-layer norm; False to use post-layer norm.
|
||||||
vgg_frontend:
|
vgg_frontend:
|
||||||
@ -79,16 +74,16 @@ class Transformer(nn.Module):
|
|||||||
self.feat_batchnorm = nn.BatchNorm1d(num_features)
|
self.feat_batchnorm = nn.BatchNorm1d(num_features)
|
||||||
|
|
||||||
self.num_features = num_features
|
self.num_features = num_features
|
||||||
self.num_classes = num_classes
|
self.output_dim = output_dim
|
||||||
self.subsampling_factor = subsampling_factor
|
self.subsampling_factor = subsampling_factor
|
||||||
if subsampling_factor != 4:
|
if subsampling_factor != 4:
|
||||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||||
|
|
||||||
# self.encoder_embed converts the input of shape (N, T, num_classes)
|
# self.encoder_embed converts the input of shape (N, T, num_features)
|
||||||
# to the shape (N, T//subsampling_factor, d_model).
|
# to the shape (N, T//subsampling_factor, d_model).
|
||||||
# That is, it does two things simultaneously:
|
# That is, it does two things simultaneously:
|
||||||
# (1) subsampling: T -> T//subsampling_factor
|
# (1) subsampling: T -> T//subsampling_factor
|
||||||
# (2) embedding: num_classes -> d_model
|
# (2) embedding: num_features -> d_model
|
||||||
if vgg_frontend:
|
if vgg_frontend:
|
||||||
self.encoder_embed = VggSubsampling(num_features, d_model)
|
self.encoder_embed = VggSubsampling(num_features, d_model)
|
||||||
else:
|
else:
|
||||||
@ -117,279 +112,45 @@ class Transformer(nn.Module):
|
|||||||
|
|
||||||
# TODO(fangjun): remove dropout
|
# TODO(fangjun): remove dropout
|
||||||
self.encoder_output_layer = nn.Sequential(
|
self.encoder_output_layer = nn.Sequential(
|
||||||
nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
|
nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
if num_decoder_layers > 0:
|
|
||||||
self.decoder_num_class = (
|
|
||||||
self.num_classes
|
|
||||||
) # bpe model already has sos/eos symbol
|
|
||||||
|
|
||||||
self.decoder_embed = nn.Embedding(
|
|
||||||
num_embeddings=self.decoder_num_class, embedding_dim=d_model
|
|
||||||
)
|
|
||||||
self.decoder_pos = PositionalEncoding(d_model, dropout)
|
|
||||||
|
|
||||||
decoder_layer = TransformerDecoderLayer(
|
|
||||||
d_model=d_model,
|
|
||||||
nhead=nhead,
|
|
||||||
dim_feedforward=dim_feedforward,
|
|
||||||
dropout=dropout,
|
|
||||||
normalize_before=normalize_before,
|
|
||||||
)
|
|
||||||
|
|
||||||
if normalize_before:
|
|
||||||
decoder_norm = nn.LayerNorm(d_model)
|
|
||||||
else:
|
|
||||||
decoder_norm = None
|
|
||||||
|
|
||||||
self.decoder = nn.TransformerDecoder(
|
|
||||||
decoder_layer=decoder_layer,
|
|
||||||
num_layers=num_decoder_layers,
|
|
||||||
norm=decoder_norm,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.decoder_output_layer = torch.nn.Linear(
|
|
||||||
d_model, self.decoder_num_class
|
|
||||||
)
|
|
||||||
|
|
||||||
self.decoder_criterion = LabelSmoothingLoss()
|
|
||||||
else:
|
|
||||||
self.decoder_criterion = None
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
|
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
The input tensor. Its shape is (N, T, C).
|
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||||
supervision:
|
x_lens:
|
||||||
Supervision in lhotse format.
|
A tensor of shape (batch_size,) containing the number of frames in
|
||||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
`x` before padding.
|
||||||
(CAUTION: It contains length information, i.e., start and number of
|
|
||||||
frames, before subsampling)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a tuple containing 3 tensors:
|
Return a tuple containing 2 tensors:
|
||||||
- CTC output for ctc decoding. Its shape is (N, T, C)
|
- logits, its shape is (batch_size, output_seq_len, output_dim)
|
||||||
- Encoder output with shape (T, N, C). It can be used as key and
|
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||||
value for the decoder.
|
of frames in `logits` before padding.
|
||||||
- Encoder output padding mask. It can be used as
|
|
||||||
memory_key_padding_mask for the decoder. Its shape is (N, T).
|
|
||||||
It is None if `supervision` is None.
|
|
||||||
"""
|
"""
|
||||||
if self.use_feat_batchnorm:
|
if self.use_feat_batchnorm:
|
||||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||||
x = self.feat_batchnorm(x)
|
x = self.feat_batchnorm(x)
|
||||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||||
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
|
||||||
x, supervision
|
|
||||||
)
|
|
||||||
x = self.ctc_output(encoder_memory)
|
|
||||||
return x, encoder_memory, memory_key_padding_mask
|
|
||||||
|
|
||||||
def run_encoder(
|
|
||||||
self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
||||||
"""Run the transformer encoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x:
|
|
||||||
The model input. Its shape is (N, T, C).
|
|
||||||
supervisions:
|
|
||||||
Supervision in lhotse format.
|
|
||||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
|
||||||
CAUTION: It contains length information, i.e., start and number of
|
|
||||||
frames, before subsampling
|
|
||||||
It is read directly from the batch, without any sorting. It is used
|
|
||||||
to compute the encoder padding mask, which is used as memory key
|
|
||||||
padding mask for the decoder.
|
|
||||||
Returns:
|
|
||||||
Return a tuple with two tensors:
|
|
||||||
- The encoder output, with shape (T, N, C)
|
|
||||||
- encoder padding mask, with shape (N, T).
|
|
||||||
The mask is None if `supervisions` is None.
|
|
||||||
It is used as memory key padding mask in the decoder.
|
|
||||||
"""
|
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
x = self.encoder_pos(x)
|
x = self.encoder_pos(x)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
mask = encoder_padding_mask(x.size(0), supervisions)
|
|
||||||
mask = mask.to(x.device) if mask is not None else None
|
# Caution: We assume the subsampling factor is 4!
|
||||||
|
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||||
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
|
mask = make_pad_mask(lengths)
|
||||||
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
|
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
|
||||||
|
|
||||||
return x, mask
|
logits = self.encoder_output_layer(x)
|
||||||
|
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
|
return logits, lengths
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x:
|
|
||||||
The output tensor from the transformer encoder.
|
|
||||||
Its shape is (T, N, C)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Return a tensor that can be used for CTC decoding.
|
|
||||||
Its shape is (N, T, C)
|
|
||||||
"""
|
|
||||||
x = self.encoder_output_layer(x)
|
|
||||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
|
||||||
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
|
|
||||||
return x
|
|
||||||
|
|
||||||
@torch.jit.export
|
|
||||||
def decoder_forward(
|
|
||||||
self,
|
|
||||||
memory: torch.Tensor,
|
|
||||||
memory_key_padding_mask: torch.Tensor,
|
|
||||||
token_ids: List[List[int]],
|
|
||||||
sos_id: int,
|
|
||||||
eos_id: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
memory:
|
|
||||||
It's the output of the encoder with shape (T, N, C)
|
|
||||||
memory_key_padding_mask:
|
|
||||||
The padding mask from the encoder.
|
|
||||||
token_ids:
|
|
||||||
A list-of-list IDs. Each sublist contains IDs for an utterance.
|
|
||||||
The IDs can be either phone IDs or word piece IDs.
|
|
||||||
sos_id:
|
|
||||||
sos token id
|
|
||||||
eos_id:
|
|
||||||
eos token id
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A scalar, the **sum** of label smoothing loss over utterances
|
|
||||||
in the batch without any normalization.
|
|
||||||
"""
|
|
||||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
|
||||||
ys_in = [torch.tensor(y) for y in ys_in]
|
|
||||||
ys_in_pad = pad_sequence(
|
|
||||||
ys_in, batch_first=True, padding_value=float(eos_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
|
||||||
ys_out = [torch.tensor(y) for y in ys_out]
|
|
||||||
ys_out_pad = pad_sequence(
|
|
||||||
ys_out, batch_first=True, padding_value=float(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
device = memory.device
|
|
||||||
ys_in_pad = ys_in_pad.to(device)
|
|
||||||
ys_out_pad = ys_out_pad.to(device)
|
|
||||||
|
|
||||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
|
|
||||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
|
||||||
# TODO: Use length information to create the decoder padding mask
|
|
||||||
# We set the first column to False since the first column in ys_in_pad
|
|
||||||
# contains sos_id, which is the same as eos_id in our current setting.
|
|
||||||
tgt_key_padding_mask[:, 0] = False
|
|
||||||
|
|
||||||
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
|
|
||||||
tgt = self.decoder_pos(tgt)
|
|
||||||
tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
|
||||||
pred_pad = self.decoder(
|
|
||||||
tgt=tgt,
|
|
||||||
memory=memory,
|
|
||||||
tgt_mask=tgt_mask,
|
|
||||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
|
||||||
memory_key_padding_mask=memory_key_padding_mask,
|
|
||||||
) # (T, N, C)
|
|
||||||
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
|
|
||||||
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
|
|
||||||
|
|
||||||
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
|
|
||||||
|
|
||||||
return decoder_loss
|
|
||||||
|
|
||||||
@torch.jit.export
|
|
||||||
def decoder_nll(
|
|
||||||
self,
|
|
||||||
memory: torch.Tensor,
|
|
||||||
memory_key_padding_mask: torch.Tensor,
|
|
||||||
token_ids: List[torch.Tensor],
|
|
||||||
sos_id: int,
|
|
||||||
eos_id: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
memory:
|
|
||||||
It's the output of the encoder with shape (T, N, C)
|
|
||||||
memory_key_padding_mask:
|
|
||||||
The padding mask from the encoder.
|
|
||||||
token_ids:
|
|
||||||
A list-of-list IDs (e.g., word piece IDs).
|
|
||||||
Each sublist represents an utterance.
|
|
||||||
sos_id:
|
|
||||||
The token ID for SOS.
|
|
||||||
eos_id:
|
|
||||||
The token ID for EOS.
|
|
||||||
Returns:
|
|
||||||
A 2-D tensor of shape (len(token_ids), max_token_length)
|
|
||||||
representing the cross entropy loss (i.e., negative log-likelihood).
|
|
||||||
"""
|
|
||||||
# The common part between this function and decoder_forward could be
|
|
||||||
# extracted as a separate function.
|
|
||||||
if isinstance(token_ids[0], torch.Tensor):
|
|
||||||
# This branch is executed by torchscript in C++.
|
|
||||||
# See https://github.com/k2-fsa/k2/pull/870
|
|
||||||
# https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
|
|
||||||
token_ids = [tolist(t) for t in token_ids]
|
|
||||||
|
|
||||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
|
||||||
ys_in = [torch.tensor(y) for y in ys_in]
|
|
||||||
ys_in_pad = pad_sequence(
|
|
||||||
ys_in, batch_first=True, padding_value=float(eos_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
|
||||||
ys_out = [torch.tensor(y) for y in ys_out]
|
|
||||||
ys_out_pad = pad_sequence(
|
|
||||||
ys_out, batch_first=True, padding_value=float(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
device = memory.device
|
|
||||||
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
|
||||||
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
|
|
||||||
|
|
||||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
|
||||||
device
|
|
||||||
)
|
|
||||||
|
|
||||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
|
||||||
# TODO: Use length information to create the decoder padding mask
|
|
||||||
# We set the first column to False since the first column in ys_in_pad
|
|
||||||
# contains sos_id, which is the same as eos_id in our current setting.
|
|
||||||
tgt_key_padding_mask[:, 0] = False
|
|
||||||
|
|
||||||
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
|
|
||||||
tgt = self.decoder_pos(tgt)
|
|
||||||
tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
|
||||||
pred_pad = self.decoder(
|
|
||||||
tgt=tgt,
|
|
||||||
memory=memory,
|
|
||||||
tgt_mask=tgt_mask,
|
|
||||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
|
||||||
memory_key_padding_mask=memory_key_padding_mask,
|
|
||||||
) # (T, B, F)
|
|
||||||
pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F)
|
|
||||||
pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F)
|
|
||||||
# nll: negative log-likelihood
|
|
||||||
nll = torch.nn.functional.cross_entropy(
|
|
||||||
pred_pad.view(-1, self.decoder_num_class),
|
|
||||||
ys_out_pad.view(-1),
|
|
||||||
ignore_index=-1,
|
|
||||||
reduction="none",
|
|
||||||
)
|
|
||||||
|
|
||||||
nll = nll.view(pred_pad.shape[0], -1)
|
|
||||||
|
|
||||||
return nll
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
class TransformerEncoderLayer(nn.Module):
|
||||||
@ -494,138 +255,6 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
return src
|
return src
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoderLayer(nn.Module):
|
|
||||||
"""
|
|
||||||
Modified from torch.nn.TransformerDecoderLayer.
|
|
||||||
Add support of normalize_before,
|
|
||||||
i.e., use layer_norm before the first block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
d_model:
|
|
||||||
the number of expected features in the input (required).
|
|
||||||
nhead:
|
|
||||||
the number of heads in the multiheadattention models (required).
|
|
||||||
dim_feedforward:
|
|
||||||
the dimension of the feedforward network model (default=2048).
|
|
||||||
dropout:
|
|
||||||
the dropout value (default=0.1).
|
|
||||||
activation:
|
|
||||||
the activation function of intermediate layer, relu or
|
|
||||||
gelu (default=relu).
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
|
||||||
>>> memory = torch.rand(10, 32, 512)
|
|
||||||
>>> tgt = torch.rand(20, 32, 512)
|
|
||||||
>>> out = decoder_layer(tgt, memory)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model: int,
|
|
||||||
nhead: int,
|
|
||||||
dim_feedforward: int = 2048,
|
|
||||||
dropout: float = 0.1,
|
|
||||||
activation: str = "relu",
|
|
||||||
normalize_before: bool = True,
|
|
||||||
) -> None:
|
|
||||||
super(TransformerDecoderLayer, self).__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
|
|
||||||
self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
|
|
||||||
# Implementation of Feedforward model
|
|
||||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(d_model)
|
|
||||||
self.norm2 = nn.LayerNorm(d_model)
|
|
||||||
self.norm3 = nn.LayerNorm(d_model)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
self.dropout3 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
|
|
||||||
self.normalize_before = normalize_before
|
|
||||||
|
|
||||||
def __setstate__(self, state):
|
|
||||||
if "activation" not in state:
|
|
||||||
state["activation"] = nn.functional.relu
|
|
||||||
super(TransformerDecoderLayer, self).__setstate__(state)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
tgt: torch.Tensor,
|
|
||||||
memory: torch.Tensor,
|
|
||||||
tgt_mask: Optional[torch.Tensor] = None,
|
|
||||||
memory_mask: Optional[torch.Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
|
||||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Pass the inputs (and mask) through the decoder layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tgt:
|
|
||||||
the sequence to the decoder layer (required).
|
|
||||||
memory:
|
|
||||||
the sequence from the last layer of the encoder (required).
|
|
||||||
tgt_mask:
|
|
||||||
the mask for the tgt sequence (optional).
|
|
||||||
memory_mask:
|
|
||||||
the mask for the memory sequence (optional).
|
|
||||||
tgt_key_padding_mask:
|
|
||||||
the mask for the tgt keys per batch (optional).
|
|
||||||
memory_key_padding_mask:
|
|
||||||
the mask for the memory keys per batch (optional).
|
|
||||||
|
|
||||||
Shape:
|
|
||||||
tgt: (T, N, E).
|
|
||||||
memory: (S, N, E).
|
|
||||||
tgt_mask: (T, T).
|
|
||||||
memory_mask: (T, S).
|
|
||||||
tgt_key_padding_mask: (N, T).
|
|
||||||
memory_key_padding_mask: (N, S).
|
|
||||||
S is the source sequence length, T is the target sequence length,
|
|
||||||
N is the batch size, E is the feature number
|
|
||||||
"""
|
|
||||||
residual = tgt
|
|
||||||
if self.normalize_before:
|
|
||||||
tgt = self.norm1(tgt)
|
|
||||||
tgt2 = self.self_attn(
|
|
||||||
tgt,
|
|
||||||
tgt,
|
|
||||||
tgt,
|
|
||||||
attn_mask=tgt_mask,
|
|
||||||
key_padding_mask=tgt_key_padding_mask,
|
|
||||||
)[0]
|
|
||||||
tgt = residual + self.dropout1(tgt2)
|
|
||||||
if not self.normalize_before:
|
|
||||||
tgt = self.norm1(tgt)
|
|
||||||
|
|
||||||
residual = tgt
|
|
||||||
if self.normalize_before:
|
|
||||||
tgt = self.norm2(tgt)
|
|
||||||
tgt2 = self.src_attn(
|
|
||||||
tgt,
|
|
||||||
memory,
|
|
||||||
memory,
|
|
||||||
attn_mask=memory_mask,
|
|
||||||
key_padding_mask=memory_key_padding_mask,
|
|
||||||
)[0]
|
|
||||||
tgt = residual + self.dropout2(tgt2)
|
|
||||||
if not self.normalize_before:
|
|
||||||
tgt = self.norm2(tgt)
|
|
||||||
|
|
||||||
residual = tgt
|
|
||||||
if self.normalize_before:
|
|
||||||
tgt = self.norm3(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
|
||||||
tgt = residual + self.dropout3(tgt2)
|
|
||||||
if not self.normalize_before:
|
|
||||||
tgt = self.norm3(tgt)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation: str):
|
def _get_activation_fn(activation: str):
|
||||||
if activation == "relu":
|
if activation == "relu":
|
||||||
return nn.functional.relu
|
return nn.functional.relu
|
||||||
@ -798,149 +427,3 @@ class Noam(object):
|
|||||||
self.optimizer.load_state_dict(state_dict["optimizer"])
|
self.optimizer.load_state_dict(state_dict["optimizer"])
|
||||||
else:
|
else:
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
def encoder_padding_mask(
|
|
||||||
max_len: int, supervisions: Optional[Supervisions] = None
|
|
||||||
) -> Optional[torch.Tensor]:
|
|
||||||
"""Make mask tensor containing indexes of padded part.
|
|
||||||
|
|
||||||
TODO::
|
|
||||||
This function **assumes** that the model uses
|
|
||||||
a subsampling factor of 4. We should remove that
|
|
||||||
assumption later.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_len:
|
|
||||||
Maximum length of input features.
|
|
||||||
CAUTION: It is the length after subsampling.
|
|
||||||
supervisions:
|
|
||||||
Supervision in lhotse format.
|
|
||||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
|
||||||
(CAUTION: It contains length information, i.e., start and number of
|
|
||||||
frames, before subsampling)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Mask tensor of dimension (batch_size, input_length),
|
|
||||||
True denote the masked indices.
|
|
||||||
"""
|
|
||||||
if supervisions is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
supervision_segments = torch.stack(
|
|
||||||
(
|
|
||||||
supervisions["sequence_idx"],
|
|
||||||
supervisions["start_frame"],
|
|
||||||
supervisions["num_frames"],
|
|
||||||
),
|
|
||||||
1,
|
|
||||||
).to(torch.int32)
|
|
||||||
|
|
||||||
lengths = [
|
|
||||||
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
|
|
||||||
]
|
|
||||||
for idx in range(supervision_segments.size(0)):
|
|
||||||
# Note: TorchScript doesn't allow to unpack tensors as tuples
|
|
||||||
sequence_idx = supervision_segments[idx, 0].item()
|
|
||||||
start_frame = supervision_segments[idx, 1].item()
|
|
||||||
num_frames = supervision_segments[idx, 2].item()
|
|
||||||
lengths[sequence_idx] = start_frame + num_frames
|
|
||||||
|
|
||||||
lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
|
|
||||||
bs = int(len(lengths))
|
|
||||||
seq_range = torch.arange(0, max_len, dtype=torch.int64)
|
|
||||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
|
|
||||||
# Note: TorchScript doesn't implement Tensor.new()
|
|
||||||
seq_length_expand = torch.tensor(
|
|
||||||
lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
|
|
||||||
).unsqueeze(-1)
|
|
||||||
mask = seq_range_expand >= seq_length_expand
|
|
||||||
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
def decoder_padding_mask(
|
|
||||||
ys_pad: torch.Tensor, ignore_id: int = -1
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Generate a length mask for input.
|
|
||||||
|
|
||||||
The masked position are filled with True,
|
|
||||||
Unmasked positions are filled with False.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ys_pad:
|
|
||||||
padded tensor of dimension (batch_size, input_length).
|
|
||||||
ignore_id:
|
|
||||||
the ignored number (the padding number) in ys_pad
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor:
|
|
||||||
a bool tensor of the same shape as the input tensor.
|
|
||||||
"""
|
|
||||||
ys_mask = ys_pad == ignore_id
|
|
||||||
return ys_mask
|
|
||||||
|
|
||||||
|
|
||||||
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
|
|
||||||
"""Generate a square mask for the sequence. The masked positions are
|
|
||||||
filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
|
||||||
The mask can be used for masked self-attention.
|
|
||||||
|
|
||||||
For instance, if sz is 3, it returns::
|
|
||||||
|
|
||||||
tensor([[0., -inf, -inf],
|
|
||||||
[0., 0., -inf],
|
|
||||||
[0., 0., 0]])
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sz: mask size
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A square mask of dimension (sz, sz)
|
|
||||||
"""
|
|
||||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
|
||||||
mask = (
|
|
||||||
mask.float()
|
|
||||||
.masked_fill(mask == 0, float("-inf"))
|
|
||||||
.masked_fill(mask == 1, float(0.0))
|
|
||||||
)
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
|
|
||||||
"""Prepend sos_id to each utterance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_ids:
|
|
||||||
A list-of-list of token IDs. Each sublist contains
|
|
||||||
token IDs (e.g., word piece IDs) of an utterance.
|
|
||||||
sos_id:
|
|
||||||
The ID of the SOS token.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
Return a new list-of-list, where each sublist starts
|
|
||||||
with SOS ID.
|
|
||||||
"""
|
|
||||||
return [[sos_id] + utt for utt in token_ids]
|
|
||||||
|
|
||||||
|
|
||||||
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
|
|
||||||
"""Append eos_id to each utterance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_ids:
|
|
||||||
A list-of-list of token IDs. Each sublist contains
|
|
||||||
token IDs (e.g., word piece IDs) of an utterance.
|
|
||||||
eos_id:
|
|
||||||
The ID of the EOS token.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
Return a new list-of-list, where each sublist ends
|
|
||||||
with EOS ID.
|
|
||||||
"""
|
|
||||||
return [utt + [eos_id] for utt in token_ids]
|
|
||||||
|
|
||||||
|
|
||||||
def tolist(t: torch.Tensor) -> List[int]:
|
|
||||||
"""Used by jit"""
|
|
||||||
return torch.jit.annotate(List[int], t.tolist())
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user