Reformat conformer.py by black

pkufool 2021-08-23 09:25:17 +08:00
parent 9808d30282
commit 8c75c0abeb

@@ -1,21 +1,7 @@
 #!/usr/bin/env python3
-# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
-#                2021  University of Chinese Academy of Sciences (author: Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# Apache 2.0
 
 import math
 import warnings
@@ -29,18 +15,18 @@ from transformer import Supervisions, Transformer, encoder_padding_mask
 class Conformer(Transformer):
     """
     Args:
         num_features (int): Number of input features
         num_classes (int): Number of output classes
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
         d_model (int): attention dimension
         nhead (int): number of head
         dim_feedforward (int): feedforward dimention
         num_encoder_layers (int): number of encoder layers
         num_decoder_layers (int): number of decoder layers
         dropout (float): dropout rate
         cnn_module_kernel (int): Kernel size of convolution module
         normalize_before (bool): whether to use layer_norm before the first block.
         vgg_frontend (bool): whether to use vgg frontend.
     """
 
     def __init__(
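The Args block above lists the constructor parameters by name only, so a minimal construction sketch may be easier to read. This is illustrative rather than taken from the diff: the values are hypothetical, and it assumes the recipe's conformer.py is on the Python path.

    from conformer import Conformer

    # Hypothetical sizes; pick values matching your features and output units.
    model = Conformer(
        num_features=80,          # e.g. 80-dim filter-bank features
        num_classes=500,          # e.g. BPE vocabulary size
        subsampling_factor=4,
        d_model=256,
        nhead=4,
        dim_feedforward=2048,
        num_encoder_layers=12,
        num_decoder_layers=6,
        dropout=0.1,
        cnn_module_kernel=31,
        normalize_before=True,
        vgg_frontend=False,
    )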
@@ -94,8 +80,8 @@ class Conformer(Transformer):
         if self.normalize_before and self.is_espnet_structure:
             self.after_norm = nn.LayerNorm(d_model)
         else:
-            # Note: TorchScript detects that self.after_norm could be used
-            # inside forward() and throws an error without this change.
+            # Note: TorchScript detects that self.after_norm could be used inside forward()
+            # and throws an error without this change.
             self.after_norm = identity
 
     def run_encoder(
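The comment in this hunk refers to a TorchScript constraint: an attribute read inside forward() must be defined, with a callable value, on every construction path, which is why the module falls back to a no-op `identity` function instead of leaving `self.after_norm` unset. A self-contained sketch of the same pattern; the class name and sizes below are made up for illustration:

    import torch
    import torch.nn as nn


    def identity(x: torch.Tensor) -> torch.Tensor:
        # No-op stand-in so that self.after_norm is always defined and callable.
        return x


    class TinyEncoder(nn.Module):
        def __init__(self, d_model: int, normalize_before: bool) -> None:
            super().__init__()
            self.linear = nn.Linear(d_model, d_model)
            if normalize_before:
                self.after_norm = nn.LayerNorm(d_model)
            else:
                self.after_norm = identity

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # TorchScript sees self.after_norm used here, so it must be
            # assigned on both branches of __init__.
            return self.after_norm(self.linear(x))


    scripted = torch.jit.script(TinyEncoder(d_model=16, normalize_before=True))
    print(scripted(torch.randn(2, 16)).shape)  # torch.Size([2, 16])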
@@ -107,7 +93,7 @@ class Conformer(Transformer):
             The model input. Its shape is [N, T, C].
           supervisions:
             Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
             CAUTION: It contains length information, i.e., start and number of
             frames, before subsampling
             It is read directly from the batch, without any sorting. It is used
@@ -133,18 +119,17 @@
 class ConformerEncoderLayer(nn.Module):
-    """ConformerEncoderLayer is made up of self-attn, feedforward and
-    convolution networks.
-
+    """
+    ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
     See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
 
     Args:
         d_model: the number of expected features in the input (required).
         nhead: the number of heads in the multiheadattention models (required).
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
         normalize_before: whether to use layer_norm before the first block.
 
     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -208,7 +193,8 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
-        """Pass the input through the encoder layer.
+        """
+        Pass the input through the encoder layer.
 
         Args:
             src: the sequence to the encoder layer (required).
@@ -315,8 +301,7 @@ class ConformerEncoder(nn.TransformerEncoder):
             pos_emb: (N, 2*S-1, E)
             mask: (S, S).
             src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length,
-            N is the batch size, E is the feature number
+            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
         """
 
         output = src
@@ -411,7 +396,7 @@ class RelPositionalEncoding(torch.nn.Module):
             :,
             self.pe.size(1) // 2
             - x.size(1)
-            + 1: self.pe.size(1) // 2
+            + 1 : self.pe.size(1) // 2  # noqa E203
             + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
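The `# noqa E203` added in this hunk reflects a known disagreement between black and flake8: when a slice bound is a compound expression, black keeps a space before the colon, which flake8 reports as E203. A small self-contained sketch of the same centered-slice arithmetic; the buffer and sizes here are made up, only the indexing pattern mirrors the code above:

    import torch

    max_len = 10
    # A stand-in for self.pe: one row of 2*max_len - 1 relative positions.
    pe = torch.arange(2 * max_len - 1).unsqueeze(0)  # shape (1, 19)
    T = 4  # length of the current input, i.e. x.size(1)

    pos_emb = pe[
        :,
        pe.size(1) // 2
        - T
        + 1 : pe.size(1) // 2  # noqa: E203  (black keeps the space before ":")
        + T,
    ]
    print(pos_emb.shape)  # torch.Size([1, 7]) == (1, 2*T - 1)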
@@ -485,48 +470,42 @@ class RelPositionMultiheadAttention(nn.Module):
         Args:
             query, key, value: map a query and a set of key-value pairs to an output.
             pos_emb: Positional embedding tensor
-            key_padding_mask: if provided, specified padding elements in the key
-                will be ignored by the attention. When given a binary mask and
-                a value is True, the corresponding value on the attention layer
-                will be ignored. When given a byte mask and a value is non-zero,
-                the corresponding value on the attention layer will be ignored.
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention. When given a binary mask and a value is True,
+                the corresponding value on the attention layer will be ignored. When given
+                a byte mask and a value is non-zero, the corresponding value on the attention
+                layer will be ignored
             need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions.
-                A 2D mask will be broadcasted for all the batches while a 3D mask
-                allows to specify a different mask for the entries of each batch.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
 
         Shape:
             - Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size,
-              S is the source sequence length. If a ByteTensor is provided,
-              the non-zero positions will be ignored while the position
-              with the zero positions will be unchanged. If a BoolTensor is provided,
-              the positions with the value of ``True`` will be ignored while
-              the position with the value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
-              S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
-              where N is the batch size, L is the target sequence length,
-              S is the source sequence length. attn_mask ensure that position
-              i is allowed to attend the unmasked positions.
-              If a ByteTensor is provided, the non-zero positions are not
-              allowed to attend while the zero positions will be unchanged.
-              If a BoolTensor is provided, positions with ``True`` is not allowed
-              to attend while ``False`` values will be unchanged.
-              If a FloatTensor is provided, it will be added to the attention weight.
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+              If a ByteTensor is provided, the non-zero positions will be ignored while the position
+              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
+              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+              is provided, it will be added to the attention weight.
 
             - Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
               L is the target sequence length, S is the source sequence length.
         """
         return self.multi_head_attention_forward(
             query,
@@ -603,43 +582,36 @@ class RelPositionMultiheadAttention(nn.Module):
                 be ignored by the attention. This is an binary mask. When the value is True,
                 the corresponding value on the attention layer will be filled with -inf.
             need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions.
-                A 2D mask will be broadcasted for all the batches while a 3D mask
-                allows to specify a different mask for the entries of each batch.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
 
         Shape:
             Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where
-              L is the target sequencelength, N is the batch size,
-              E is the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size,
-              S is the source sequence length. If a ByteTensor is provided,
-              the non-zero positions will be ignored while the zero positions
-              will be unchanged. If a BoolTensor is provided, the positions with the
-              value of ``True`` will be ignored while the position with the
-              value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
-              S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
-              where N is the batch size, L is the target sequence length,
-              S is the source sequence length. attn_mask ensures that position
-              i is allowed to attend the unmasked positions.
-              If a ByteTensor is provided, the non-zero positions are not
-              allowed to attend while the zero positions will be unchanged.
-              If a BoolTensor is provided, positions with ``True`` are not
-              allowed to attend while ``False`` values will be unchanged.
-              If a FloatTensor is provided, it will be added to the attention weight.
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+              length, N is the batch size, E is the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+              If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+              will be unchanged. If a BoolTensor is provided, the positions with the
+              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+              is provided, it will be added to the attention weight.
 
             Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
               L is the target sequence length, S is the source sequence length.
         """
 
         tgt_len, bsz, embed_dim = query.size()
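The Shape sections in the last two hunks are easier to check against concrete tensors. The sketch below is a usage illustration under assumptions: it presumes the recipe's conformer.py is importable and that RelPositionMultiheadAttention is constructed with (embed_dim, num_heads) in the style of torch.nn.MultiheadAttention; verify the actual constructor and forward signature in the file before relying on it. All sizes are hypothetical.

    import torch
    from conformer import RelPositionMultiheadAttention

    L, N, E, H = 6, 2, 64, 4                 # target length, batch size, embed dim, heads
    attn = RelPositionMultiheadAttention(E, H)  # assumed constructor arguments

    q = torch.randn(L, N, E)                 # query: (L, N, E)
    k = torch.randn(L, N, E)                 # key:   (S, N, E); self-attention, so S == L
    v = torch.randn(L, N, E)                 # value: (S, N, E)
    pos_emb = torch.randn(N, 2 * L - 1, E)   # pos_emb: (N, 2*L-1, E) as documented

    attn_output, attn_weights = attn(q, k, v, pos_emb)
    print(attn_output.shape)                 # expected (L, N, E); attn_weights,
                                             # if returned, should be (N, L, S)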