Mirror of https://github.com/k2-fsa/icefall.git
Reformat conformer.py by black

commit 8c75c0abeb
parent 9808d30282
@@ -1,21 +1,7 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#                2021 University of Chinese Academy of Sciences (author: Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 
+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+# Apache 2.0
 
 import math
 import warnings
@@ -94,8 +80,8 @@ class Conformer(Transformer):
         if self.normalize_before and self.is_espnet_structure:
             self.after_norm = nn.LayerNorm(d_model)
         else:
-            # Note: TorchScript detects that self.after_norm could be used
-            # inside forward() and throws an error without this change.
+            # Note: TorchScript detects that self.after_norm could be used inside forward()
+            # and throws an error without this change.
             self.after_norm = identity
 
     def run_encoder(
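The comment being rewrapped here points at a real TorchScript constraint: self.after_norm has to be assigned on both branches of the if, or torch.jit.script() reports a possibly-undefined attribute when forward() later calls it. A minimal sketch of that pattern, using a hypothetical module rather than the icefall class itself:

import torch
import torch.nn as nn


def identity(x: torch.Tensor) -> torch.Tensor:
    # No-op stand-in, mirroring the `identity` helper assigned in the hunk above.
    return x


class TinyEncoder(nn.Module):
    def __init__(self, d_model: int, normalize_before: bool) -> None:
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        if normalize_before:
            self.after_norm = nn.LayerNorm(d_model)
        else:
            # Assign *something* on this branch; leaving the attribute undefined
            # is what triggers the TorchScript error the comment describes.
            self.after_norm = identity

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.after_norm(self.linear(x))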
@@ -107,7 +93,7 @@ class Conformer(Transformer):
             The model input. Its shape is [N, T, C].
           supervisions:
             Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
             CAUTION: It contains length information, i.e., start and number of
             frames, before subsampling
             It is read directly from the batch, without any sorting. It is used
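The supervision dict referenced in this docstring carries per-segment timing and text for each cut in the batch. A hedged sketch of what that field can look like, with key names following lhotse's K2SpeechRecognitionDataset and made-up values:

import torch

supervisions = {
    "sequence_idx": torch.tensor([0, 1]),     # index of the cut each segment belongs to
    "text": ["HELLO WORLD", "FOO BAR"],       # transcripts
    "start_frame": torch.tensor([0, 0]),      # segment start, in frames before subsampling
    "num_frames": torch.tensor([1000, 850]),  # segment length, in frames before subsampling
}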
@@ -133,9 +119,8 @@ class Conformer(Transformer):
 
 
 class ConformerEncoderLayer(nn.Module):
-    """ConformerEncoderLayer is made up of self-attn, feedforward and
-    convolution networks.
+    """
+    ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
 
     See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
 
     Args:
@@ -208,7 +193,8 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
-        """Pass the input through the encoder layer.
+        """
+        Pass the input through the encoder layer.
 
         Args:
             src: the sequence to the encoder layer (required).
@@ -315,8 +301,7 @@ class ConformerEncoder(nn.TransformerEncoder):
             pos_emb: (N, 2*S-1, E)
             mask: (S, S).
             src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length,
-            N is the batch size, E is the feature number
+            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
 
         """
         output = src
@@ -411,7 +396,7 @@ class RelPositionalEncoding(torch.nn.Module):
             :,
             self.pe.size(1) // 2
             - x.size(1)
-            + 1: self.pe.size(1) // 2
+            + 1 : self.pe.size(1) // 2  # noqa E203
             + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
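The expression being reformatted here slices a cached table of relative-position embeddings (conventionally shaped (1, 2*max_len-1, d_model)) down to the 2*T-1 entries centered on offset zero for an input of length T. Black puts a space before the slice colon because its left operand is a compound expression, and flake8's E203 check then has to be silenced with the noqa marker. A standalone sketch of the arithmetic, with local stand-ins rather than the icefall module itself:

import torch

max_len, d_model, T = 10, 4, 6
pe = torch.zeros(1, 2 * max_len - 1, d_model)  # stand-in for self.pe

# The same centered slice as in the hunk above, written with black's spacing.
pos_emb = pe[
    :,
    pe.size(1) // 2
    - T
    + 1 : pe.size(1) // 2  # noqa: E203
    + T,
]
assert pos_emb.shape == (1, 2 * T - 1, d_model)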
@@ -485,46 +470,40 @@ class RelPositionMultiheadAttention(nn.Module):
         Args:
             query, key, value: map a query and a set of key-value pairs to an output.
             pos_emb: Positional embedding tensor
-            key_padding_mask: if provided, specified padding elements in the key
-                will be ignored by the attention. When given a binary mask and
-                a value is True, the corresponding value on the attention layer
-                will be ignored. When given a byte mask and a value is non-zero,
-                the corresponding value on the attention layer will be ignored.
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention. When given a binary mask and a value is True,
+                the corresponding value on the attention layer will be ignored. When given
+                a byte mask and a value is non-zero, the corresponding value on the attention
+                layer will be ignored
             need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions.
-                A 2D mask will be broadcasted for all the batches while a 3D mask
-                allows to specify a different mask for the entries of each batch.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
 
         Shape:
             - Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size,
-              S is the source sequence length. If a ByteTensor is provided,
-              the non-zero positions will be ignored while the position
-              with the zero positions will be unchanged. If a BoolTensor is provided,
-              the positions with the value of ``True`` will be ignored while
-              the position with the value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
-              S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
-              where N is the batch size, L is the target sequence length,
-              S is the source sequence length. attn_mask ensure that position
-              i is allowed to attend the unmasked positions.
-              If a ByteTensor is provided, the non-zero positions are not
-              allowed to attend while the zero positions will be unchanged.
-              If a BoolTensor is provided, positions with ``True`` is not allowed
-              to attend while ``False`` values will be unchanged.
-              If a FloatTensor is provided, it will be added to the attention weight.
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+              If a ByteTensor is provided, the non-zero positions will be ignored while the position
+              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+              S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
+              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+              is provided, it will be added to the attention weight.
 
             - Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
@@ -603,41 +582,34 @@ class RelPositionMultiheadAttention(nn.Module):
                 be ignored by the attention. This is an binary mask. When the value is True,
                 the corresponding value on the attention layer will be filled with -inf.
             need_weights: output attn_output_weights.
-            attn_mask: 2D or 3D mask that prevents attention to certain positions.
-                A 2D mask will be broadcasted for all the batches while a 3D mask
-                allows to specify a different mask for the entries of each batch.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
 
         Shape:
             Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length,
-              N is the batch size, E is the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where
-              L is the target sequencelength, N is the batch size,
-              E is the embedding dimension.
-            - key_padding_mask: :math:`(N, S)` where N is the batch size,
-              S is the source sequence length. If a ByteTensor is provided,
-              the non-zero positions will be ignored while the zero positions
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+              the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+              length, N is the batch size, E is the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+              If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
              will be unchanged. If a BoolTensor is provided, the positions with the
-              value of ``True`` will be ignored while the position with the
-              value of ``False`` will be unchanged.
-            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
-              S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
-              where N is the batch size, L is the target sequence length,
-              S is the source sequence length. attn_mask ensures that position
-              i is allowed to attend the unmasked positions.
-              If a ByteTensor is provided, the non-zero positions are not
-              allowed to attend while the zero positions will be unchanged.
-              If a BoolTensor is provided, positions with ``True`` are not
-              allowed to attend while ``False`` values will be unchanged.
-              If a FloatTensor is provided, it will be added to the attention weight.
+              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+              is provided, it will be added to the attention weight.
 
             Outputs:
-            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
-              N is the batch size, E is the embedding dimension.
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """