Reformat conformer.py by black

pkufool 2021-08-23 09:25:17 +08:00
parent 9808d30282
commit 8c75c0abeb


@@ -1,21 +1,7 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#           2021 University of Chinese Academy of Sciences (author: Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# Apache 2.0

 import math
 import warnings
@@ -94,8 +80,8 @@ class Conformer(Transformer):
         if self.normalize_before and self.is_espnet_structure:
             self.after_norm = nn.LayerNorm(d_model)
         else:
-            # Note: TorchScript detects that self.after_norm could be used
-            # inside forward() and throws an error without this change.
+            # Note: TorchScript detects that self.after_norm could be used inside forward()
+            # and throws an error without this change.
             self.after_norm = identity

     def run_encoder(
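
The hunk above is the workaround its comment describes: the attribute is always bound to something callable, so a scripted forward() type-checks whether or not the final LayerNorm is used. A minimal sketch of the same pattern, assuming `identity` is a simple pass-through helper (its real definition is elsewhere in the file and not shown in this diff); the class here is a hypothetical toy, not the actual Conformer:

    import torch
    import torch.nn as nn

    def identity(x: torch.Tensor) -> torch.Tensor:
        # Assumed pass-through helper, mirroring the workaround noted in the diff:
        # self.after_norm is always bound to something callable.
        return x

    class TinyEncoder(nn.Module):
        def __init__(self, d_model: int, normalize_before: bool = True):
            super().__init__()
            self.linear = nn.Linear(d_model, d_model)
            if normalize_before:
                self.after_norm = nn.LayerNorm(d_model)
            else:
                self.after_norm = identity

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.after_norm(self.linear(x))
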
@@ -107,7 +93,7 @@ class Conformer(Transformer):
            The model input. Its shape is [N, T, C].
          supervisions:
            Supervision in lhotse format.
-           See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32
+           See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
            CAUTION: It contains length information, i.e., start and number of
            frames, before subsampling
            It is read directly from the batch, without any sorting. It is used
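
For orientation, the `supervisions` argument documented above follows the lhotse K2SpeechRecognitionDataset batch layout linked in the docstring. A hedged sketch of one such batch (field names taken from the linked file; values purely illustrative, and per the CAUTION above the frame counts are pre-subsampling):

    import torch

    features = torch.randn(2, 1500, 80)  # the [N, T, C] model input

    supervisions = {
        "sequence_idx": torch.tensor([0, 1]),      # which cut in the batch each supervision belongs to
        "text": ["HELLO WORLD", "FOO BAR"],        # transcripts
        "start_frame": torch.tensor([0, 0]),       # first frame of each supervision segment
        "num_frames": torch.tensor([1500, 1320]),  # frame counts before subsampling
    }
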
@@ -133,9 +119,8 @@ class ConformerEncoderLayer(nn.Module):

 class ConformerEncoderLayer(nn.Module):
-    """ConformerEncoderLayer is made up of self-attn, feedforward and
-    convolution networks.
+    """
+    ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
     See: "Conformer: Convolution-augmented Transformer for Speech Recognition"

     Args:
@@ -208,7 +193,8 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
-        """Pass the input through the encoder layer.
+        """
+        Pass the input through the encoder layer.

         Args:
             src: the sequence to the encoder layer (required).
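
A hedged usage sketch for ConformerEncoderLayer as documented in the two hunks above; the import path and constructor defaults are assumptions, the tensor shapes follow the docstrings, and a pos_emb with leading dimension 1 is allowed per the attention docstring further below:

    import torch
    from conformer import ConformerEncoderLayer  # assumed import path

    d_model, nhead = 256, 4
    encoder_layer = ConformerEncoderLayer(d_model=d_model, nhead=nhead)

    S, N = 100, 8                                 # source length, batch size
    src = torch.randn(S, N, d_model)              # (S, N, E)
    pos_emb = torch.randn(1, 2 * S - 1, d_model)  # relative positional embedding
    out = encoder_layer(src, pos_emb)             # -> (S, N, E)
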
@@ -315,8 +301,7 @@ class ConformerEncoder(nn.TransformerEncoder):
             pos_emb: (N, 2*S-1, E)
             mask: (S, S).
             src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length,
-            N is the batch size, E is the feature number
+            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number

         """
         output = src
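
Continuing the sketch above, ConformerEncoder stacks that layer and takes the same tensors plus the optional masks whose shapes this hunk lists. Again a hedged sketch under the same assumptions; num_layers and the all-False padding mask are illustrative only:

    import torch
    from conformer import ConformerEncoder, ConformerEncoderLayer  # assumed import path

    d_model, nhead, S, N = 256, 4, 100, 8
    encoder = ConformerEncoder(ConformerEncoderLayer(d_model, nhead), num_layers=6)

    src = torch.randn(S, N, d_model)                            # (S, N, E)
    pos_emb = torch.randn(1, 2 * S - 1, d_model)                # (1, 2*S-1, E)
    src_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)  # (N, S), True marks padded frames
    out = encoder(src, pos_emb, src_key_padding_mask=src_key_padding_mask)  # -> (S, N, E)
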
@@ -411,7 +396,7 @@ class RelPositionalEncoding(torch.nn.Module):
             :,
             self.pe.size(1) // 2
             - x.size(1)
-            + 1: self.pe.size(1) // 2
+            + 1 : self.pe.size(1) // 2  # noqa E203
             + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
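
The reformatted slice above selects the central 2*T-1 entries of the precomputed table self.pe, i.e. the embeddings for relative offsets from -(T-1) to +(T-1) around the zero offset stored at the middle index. A small arithmetic check of that indexing, assuming an odd table length (the concrete 9999 is only an example):

    # Suppose self.pe holds 2*5000 - 1 = 9999 positions; the zero offset sits at index 4999.
    pe_len = 9999
    center = pe_len // 2       # 4999
    T = 100                    # x.size(1)

    start = center - T + 1     # 4900 -> relative offset -(T-1)
    stop = center + T          # 5099 -> slice end (exclusive), i.e. relative offset +(T-1)
    assert stop - start == 2 * T - 1
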
@@ -485,46 +470,40 @@ class RelPositionMultiheadAttention(nn.Module):

         Args:
             query, key, value: map a query and a set of key-value pairs to an output.
             pos_emb: Positional embedding tensor
-            key_padding_mask: if provided, specified padding elements in the key
-                will be ignored by the attention. When given a binary mask and
-                a value is True, the corresponding value on the attention layer
-                will be ignored. When given a byte mask and a value is non-zero,
-                the corresponding value on the attention layer will be ignored.
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention. When given a binary mask and a value is True,
+                the corresponding value on the attention layer will be ignored. When given
+                a byte mask and a value is non-zero, the corresponding value on the attention
+                layer will be ignored
             need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions.
-                A 2D mask will be broadcasted for all the batches while a 3D mask
-                allows to specify a different mask for the entries of each batch.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

         Shape:
             - Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size,
-              S is the source sequence length. If a ByteTensor is provided,
-              the non-zero positions will be ignored while the position
-              with the zero positions will be unchanged. If a BoolTensor is provided,
-              the positions with the value of ``True`` will be ignored while
-              the position with the value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
-              S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
-              where N is the batch size, L is the target sequence length,
-              S is the source sequence length. attn_mask ensure that position
-              i is allowed to attend the unmasked positions.
-              If a ByteTensor is provided, the non-zero positions are not
-              allowed to attend while the zero positions will be unchanged.
-              If a BoolTensor is provided, positions with ``True`` is not allowed
-              to attend while ``False`` values will be unchanged.
-              If a FloatTensor is provided, it will be added to the attention weight.
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+              If a ByteTensor is provided, the non-zero positions will be ignored while the position
+              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
+              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+              is provided, it will be added to the attention weight.
             - Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
               L is the target sequence length, S is the source sequence length.
         """
@@ -603,41 +582,34 @@ class RelPositionMultiheadAttention(nn.Module):
                 be ignored by the attention. This is an binary mask. When the value is True,
                 the corresponding value on the attention layer will be filled with -inf.
             need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions.
-                A 2D mask will be broadcasted for all the batches while a 3D mask
-                allows to specify a different mask for the entries of each batch.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

         Shape:
             Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where
-              L is the target sequencelength, N is the batch size,
-              E is the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size,
-              S is the source sequence length. If a ByteTensor is provided,
-              the non-zero positions will be ignored while the zero positions
-              will be unchanged. If a BoolTensor is provided, the positions with the
-              value of ``True`` will be ignored while the position with the
-              value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
-              S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
-              where N is the batch size, L is the target sequence length,
-              S is the source sequence length. attn_mask ensures that position
-              i is allowed to attend the unmasked positions.
-              If a ByteTensor is provided, the non-zero positions are not
-              allowed to attend while the zero positions will be unchanged.
-              If a BoolTensor is provided, positions with ``True`` are not
-              allowed to attend while ``False`` values will be unchanged.
-              If a FloatTensor is provided, it will be added to the attention weight.
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+              length, N is the batch size, E is the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+              If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+              will be unchanged. If a BoolTensor is provided, the positions with the
+              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+              is provided, it will be added to the attention weight.
             Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
               L is the target sequence length, S is the source sequence length.
         """