Reformat conformer.py by black

pkufool 2021-08-23 09:25:17 +08:00
parent 9808d30282
commit 8c75c0abeb


@@ -1,21 +1,7 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
# Apache 2.0
import math
import warnings
@@ -29,18 +15,18 @@ from transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):
"""
Args:
num_features (int): Number of input features
num_classes (int): Number of output classes
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of heads
dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
num_features (int): Number of input features
num_classes (int): Number of output classes
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of heads
dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
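For orientation, a minimal sketch of constructing a model with the arguments documented above; the import path and all concrete values are placeholders assumed for illustration, not taken from this commit:

from conformer import Conformer  # assumed module path

model = Conformer(
    num_features=80,        # e.g. 80-dim filterbank features (assumed value)
    num_classes=500,        # size of the output vocabulary (assumed value)
    subsampling_factor=4,
    d_model=256,
    nhead=4,
    dim_feedforward=2048,
    num_encoder_layers=12,
    num_decoder_layers=6,
    dropout=0.1,
    cnn_module_kernel=31,
    normalize_before=True,
    vgg_frontend=False,
)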
@@ -94,8 +80,8 @@ class Conformer(Transformer):
if self.normalize_before and self.is_espnet_structure:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used
# inside forward() and throws an error without this change.
# Note: TorchScript detects that self.after_norm could be used inside forward()
# and throws an error without this change.
self.after_norm = identity
def run_encoder(
@@ -107,7 +93,7 @@ class Conformer(Transformer):
The model input. Its shape is [N, T, C].
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
@@ -133,18 +119,17 @@
class ConformerEncoderLayer(nn.Module):
"""ConformerEncoderLayer is made up of self-attn, feedforward and
convolution networks.
"""
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
normalize_before: whether to use layer_norm before the first block.
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
normalize_before: whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
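A hedged sketch of exercising such a layer in the same doctest style; the tensor sizes follow the (S, N, E) and (N, 2*S-1, E) shape conventions documented below, and passing pos_emb as the second argument is an assumption for illustration:

>>> import torch
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)      # (S, N, E)
>>> pos_emb = torch.rand(32, 19, 512)  # (N, 2*S-1, E) with S=10
>>> out = encoder_layer(src, pos_emb)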
@@ -208,7 +193,8 @@ class ConformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Pass the input through the encoder layer.
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
@@ -315,8 +301,7 @@ class ConformerEncoder(nn.TransformerEncoder):
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length,
N is the batch size, E is the feature number
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
output = src
@@ -411,7 +396,7 @@ class RelPositionalEncoding(torch.nn.Module):
:,
self.pe.size(1) // 2
- x.size(1)
+ 1: self.pe.size(1) // 2
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
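As a toy illustration (not part of the commit) of the centered slice above: for an input of length T, the window taken from self.pe spans 2*T-1 relative positions. A self-contained check with made-up sizes:

import torch

# Assumed toy sizes: max_len=6, so the table has 2*max_len-1 = 11 positions; input length T=3.
max_len, T = 6, 3
pe = torch.arange(2 * max_len - 1).unsqueeze(0)  # shape (1, 11); center index is 5
center = pe.size(1) // 2
pos = pe[:, center - T + 1 : center + T]         # indices 3..7 -> 2*T-1 = 5 positions
print(pos)                                       # tensor([[3, 4, 5, 6, 7]])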
@@ -485,48 +470,42 @@ class RelPositionMultiheadAttention(nn.Module):
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
key_padding_mask: if provided, specified padding elements in the key
will be ignored by the attention. When given a binary mask and
a value is True, the corresponding value on the attention layer
will be ignored. When given a byte mask and a value is non-zero,
the corresponding value on the attention layer will be ignored.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions.
A 2D mask will be broadcasted for all the batches while a 3D mask
allows to specify a different mask for the entries of each batch.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length,
N is the batch size, E is the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length,
N is the batch size, E is the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length,
N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size,
S is the source sequence length. If a ByteTensor is provided,
the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided,
the positions with the value of ``True`` will be ignored while
the positions with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position
i is allowed to attend the unmasked positions.
If a ByteTensor is provided, the non-zero positions are not
allowed to attend while the zero positions will be unchanged.
If a BoolTensor is provided, positions with ``True`` are not allowed
to attend while ``False`` values will be unchanged.
If a FloatTensor is provided, it will be added to the attention weight.
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length,
N is the batch size, E is the embedding dimension.
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
L is the target sequence length, S is the source sequence length.
"""
return self.multi_head_attention_forward(
query,
@@ -603,43 +582,36 @@ class RelPositionMultiheadAttention(nn.Module):
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions.
A 2D mask will be broadcasted for all the batches while a 3D mask
allows to specify a different mask for the entries of each batch.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length,
N is the batch size, E is the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length,
N is the batch size, E is the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where
L is the target sequence length, N is the batch size,
E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size,
S is the source sequence length. If a ByteTensor is provided,
the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the positions with the
value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position
i is allowed to attend the unmasked positions.
If a ByteTensor is provided, the non-zero positions are not
allowed to attend while the zero positions will be unchanged.
If a BoolTensor is provided, positions with ``True`` are not
allowed to attend while ``False`` values will be unchanged.
If a FloatTensor is provided, it will be added to the attention weight.
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
length, N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length,
N is the batch size, E is the embedding dimension.
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
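To make the shape conventions in the docstring above concrete, a small sketch with placeholder sizes (all names and numbers are illustrative only):

import torch

# Placeholder sizes: target length L, source length S, batch size N, embedding dim E.
L, S, N, E = 10, 12, 4, 256
query = torch.rand(L, N, E)
key = torch.rand(S, N, E)
value = torch.rand(S, N, E)
pos_emb = torch.rand(1, 2 * L - 1, E)                   # relative positional embeddings
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)  # True marks padded positions to ignore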