icefall (https://github.com/k2-fsa/icefall.git)

Reformat conformer.py by black

commit 8c75c0abeb
parent 9808d30282
@@ -1,21 +1,7 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#           2021 University of Chinese Academy of Sciences (author: Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.

+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# Apache 2.0

 import math
 import warnings
@@ -94,8 +80,8 @@ class Conformer(Transformer):
         if self.normalize_before and self.is_espnet_structure:
             self.after_norm = nn.LayerNorm(d_model)
         else:
-            # Note: TorchScript detects that self.after_norm could be used
-            # inside forward() and throws an error without this change.
+            # Note: TorchScript detects that self.after_norm could be used inside forward()
+            # and throws an error without this change.
             self.after_norm = identity

     def run_encoder(
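Aside on the TorchScript note in the hunk above: torch.jit.script type-checks every attribute that forward() references, no matter which constructor branch actually ran, so both branches have to assign self.after_norm. A minimal standalone sketch of the pattern (using nn.Identity in place of the file's identity helper, not the real Conformer class):

    import torch
    import torch.nn as nn


    class TinyEncoder(nn.Module):
        def __init__(self, d_model: int, normalize_before: bool):
            super().__init__()
            if normalize_before:
                self.after_norm = nn.LayerNorm(d_model)
            else:
                # Assigning a no-op here keeps torch.jit.script happy, because
                # forward() references self.after_norm unconditionally.
                self.after_norm = nn.Identity()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.after_norm(x)


    scripted = torch.jit.script(TinyEncoder(d_model=16, normalize_before=False))
    print(scripted(torch.rand(2, 3, 16)).shape)  # torch.Size([2, 3, 16])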
@@ -107,7 +93,7 @@ class Conformer(Transformer):
             The model input. Its shape is [N, T, C].
           supervisions:
             Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
             CAUTION: It contains length information, i.e., start and number of
             frames, before subsampling
             It is read directly from the batch, without any sorting. It is used
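For readers who do not want to chase the lhotse link above, a rough sketch of the shapes involved; the supervision field names below follow the linked collate code at the time of this commit and may differ in other lhotse versions:

    import torch

    N, T, C = 2, 1000, 80                     # batch size, frames, feature dim
    x = torch.rand(N, T, C)                   # model input, shape [N, T, C]

    # Hypothetical lhotse-style supervisions: start_frame/num_frames count input
    # frames *before* subsampling, exactly as the CAUTION above says.
    supervisions = {
        "sequence_idx": torch.tensor([0, 1]),
        "start_frame": torch.tensor([0, 0]),
        "num_frames": torch.tensor([1000, 864]),
        "text": ["first utterance", "second utterance"],
    }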
@@ -133,9 +119,8 @@ class Conformer(Transformer):


 class ConformerEncoderLayer(nn.Module):
-    """ConformerEncoderLayer is made up of self-attn, feedforward and
-    convolution networks.
-
+    """
+    ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
     See: "Conformer: Convolution-augmented Transformer for Speech Recognition"

     Args:
@@ -208,7 +193,8 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
-        """Pass the input through the encoder layer.
+        """
+        Pass the input through the encoder layer.

         Args:
             src: the sequence to the encoder layer (required).
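A hedged usage sketch for the layer touched by the two hunks above; the constructor arguments mirror torch.nn.TransformerEncoderLayer and the shapes follow the Shape sections later in this diff, but the exact signature should be checked against conformer.py itself:

    import torch
    from conformer import ConformerEncoderLayer  # the file being reformatted here

    S, N, E = 10, 32, 512                    # source length, batch size, embedding dim
    encoder_layer = ConformerEncoderLayer(d_model=E, nhead=8)

    src = torch.rand(S, N, E)                # (S, N, E)
    pos_emb = torch.rand(N, 2 * S - 1, E)    # (N, 2*S-1, E) relative positional embedding
    out = encoder_layer(src, pos_emb)        # output has the same shape as src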
@@ -315,8 +301,7 @@ class ConformerEncoder(nn.TransformerEncoder):
             pos_emb: (N, 2*S-1, E)
             mask: (S, S).
             src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length,
-            N is the batch size, E is the feature number
+            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number

         """
         output = src
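Because the (N, S) src_key_padding_mask convention recurs throughout this file, here is a small sketch of how such a mask is typically built from per-utterance lengths (plain PyTorch, not code from this repository):

    import torch

    lengths = torch.tensor([7, 10, 4])   # valid frames per utterance, N = 3
    S = int(lengths.max())               # padded source length

    # (N, S) BoolTensor: True marks a padded position that attention should ignore.
    src_key_padding_mask = torch.arange(S).unsqueeze(0) >= lengths.unsqueeze(1)
    print(src_key_padding_mask.shape)    # torch.Size([3, 10])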
@@ -411,7 +396,7 @@ class RelPositionalEncoding(torch.nn.Module):
             :,
             self.pe.size(1) // 2
             - x.size(1)
-            + 1: self.pe.size(1) // 2
+            + 1 : self.pe.size(1) // 2  # noqa E203
             + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
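The slice touched by this hunk (only its spacing and the # noqa marker change) selects the 2*T-1 relative positions centered in the precomputed table self.pe. A small numeric check of that arithmetic, assuming an ESPnet-style table that caches 2*max_len - 1 positions with the zero offset in the middle:

    max_len = 5
    pe_len = 2 * max_len - 1      # 9 cached positions
    center = pe_len // 2          # index 4

    T = 3                         # current sequence length, x.size(1)
    start = center - T + 1        # 2
    stop = center + T             # 7 (exclusive), i.e. the slice [start:stop]
    assert stop - start == 2 * T - 1   # 5 embeddings, one per offset -(T-1)..(T-1)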
@@ -485,46 +470,40 @@ class RelPositionMultiheadAttention(nn.Module):
         Args:
             query, key, value: map a query and a set of key-value pairs to an output.
             pos_emb: Positional embedding tensor
-            key_padding_mask: if provided, specified padding elements in the key
-                will be ignored by the attention. When given a binary mask and
-                a value is True, the corresponding value on the attention layer
-                will be ignored. When given a byte mask and a value is non-zero,
-                the corresponding value on the attention layer will be ignored.
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention. When given a binary mask and a value is True,
+                the corresponding value on the attention layer will be ignored. When given
+                a byte mask and a value is non-zero, the corresponding value on the attention
+                layer will be ignored
             need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions.
-                A 2D mask will be broadcasted for all the batches while a 3D mask
-                allows to specify a different mask for the entries of each batch.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

         Shape:
             - Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length,
-                N is the batch size, E is the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length,
-                N is the batch size, E is the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length,
-                N is the batch size, E is the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length,
-                N is the batch size, E is the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size,
-                S is the source sequence length. If a ByteTensor is provided,
-                the non-zero positions will be ignored while the position
-                with the zero positions will be unchanged. If a BoolTensor is provided,
-                the positions with the value of ``True`` will be ignored while
-                the position with the value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
-                S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
-                where N is the batch size, L is the target sequence length,
-                S is the source sequence length. attn_mask ensure that position
-                i is allowed to attend the unmasked positions.
-                If a ByteTensor is provided, the non-zero positions are not
-                allowed to attend while the zero positions will be unchanged.
-                If a BoolTensor is provided, positions with ``True`` is not allowed
-                to attend while ``False`` values will be unchanged.
-                If a FloatTensor is provided, it will be added to the attention weight.
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+                the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+                the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+                the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+                the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+                If a ByteTensor is provided, the non-zero positions will be ignored while the position
+                with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+                value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+                3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+                S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
+                positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+                while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+                is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+                is provided, it will be added to the attention weight.

             - Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
-                N is the batch size, E is the embedding dimension.
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+                E is the embedding dimension.
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
                 L is the target sequence length, S is the source sequence length.
         """
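A short sketch of the two attn_mask flavours described above, since the Bool and float conventions are easy to mix up (plain PyTorch, independent of this class):

    import torch

    L = S = 6
    # Bool mask of shape (L, S): True means position i may NOT attend position j.
    # This particular mask is causal (no attending to future positions).
    bool_mask = torch.triu(torch.ones(L, S), diagonal=1).bool()

    # Float variant: added to the attention logits, so disallowed slots become -inf.
    float_mask = torch.zeros(L, S).masked_fill(bool_mask, float("-inf"))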
@@ -603,41 +582,34 @@ class RelPositionMultiheadAttention(nn.Module):
                 be ignored by the attention. This is an binary mask. When the value is True,
                 the corresponding value on the attention layer will be filled with -inf.
             need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions.
-                A 2D mask will be broadcasted for all the batches while a 3D mask
-                allows to specify a different mask for the entries of each batch.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

         Shape:
             Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length,
-                N is the batch size, E is the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length,
-                N is the batch size, E is the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length,
-                N is the batch size, E is the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where
-                L is the target sequencelength, N is the batch size,
-                E is the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size,
-                S is the source sequence length. If a ByteTensor is provided,
-                the non-zero positions will be ignored while the zero positions
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+                the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+                the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+                the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+                length, N is the batch size, E is the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+                If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
                 will be unchanged. If a BoolTensor is provided, the positions with the
-                value of ``True`` will be ignored while the position with the
-                value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
-                S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
-                where N is the batch size, L is the target sequence length,
-                S is the source sequence length. attn_mask ensures that position
-                i is allowed to attend the unmasked positions.
-                If a ByteTensor is provided, the non-zero positions are not
-                allowed to attend while the zero positions will be unchanged.
-                If a BoolTensor is provided, positions with ``True`` are not
-                allowed to attend while ``False`` values will be unchanged.
-                If a FloatTensor is provided, it will be added to the attention weight.
+                value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+                3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+                S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+                positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+                while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+                are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+                is provided, it will be added to the attention weight.

             Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
-                N is the batch size, E is the embedding dimension.
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+                E is the embedding dimension.
             - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
                 L is the target sequence length, S is the source sequence length.
         """
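Finally, a hypothetical end-to-end call tying the pieces of this diff together. Argument order and names are taken from the docstrings above and from RelPositionalEncoding's return statement in the earlier hunk, not from a verified signature, so treat it strictly as a sketch:

    import torch
    from conformer import RelPositionalEncoding, RelPositionMultiheadAttention

    E, num_heads, T, N = 256, 4, 12, 3

    pos_enc = RelPositionalEncoding(E, 0.1)          # (d_model, dropout) assumed
    attn = RelPositionMultiheadAttention(E, num_heads)

    x = torch.rand(N, T, E)                          # (N, T, C) as fed to the encoder
    x, pos_emb = pos_enc(x)                          # pos_emb: (1, 2*T-1, E) per the docstring
    x = x.transpose(0, 1)                            # attention expects (L, N, E)

    no_padding = torch.zeros(N, T, dtype=torch.bool)
    out, weights = attn(x, x, x, pos_emb, key_padding_mask=no_padding)
    # out: (L, N, E); weights: (N, L, S)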