This commit is contained in: parent f8af9f0400, commit c875d7c1c1
@@ -381,6 +381,130 @@ class Transformer(nn.Module):
        return nll


class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT (https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
        super(TransformerEncoder, self).__init__()
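        # self.layers holds num_layers independent (deep-copied) clones of encoder_layer.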
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.enable_nested_tensor = enable_nested_tensor
        self.mask_check = mask_check

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported")
        output = src
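        # The checks below decide whether the nested-tensor "sparsity fast path" can be
        # used; a non-empty why_not_sparsity_fast_path records the reason it cannot.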
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        if not isinstance(first_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{str_first_layer} was not TransformerEncoderLayer"
        elif first_layer.norm_first:
            why_not_sparsity_fast_path = f"{str_first_layer}.norm_first was True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not first_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn.batch_first was not True"
        elif not first_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn._qkv_same_embed_dim was not True"
        elif not first_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = f"{str_first_layer}.activation_relu_or_gelu was not True"
        elif not (first_layer.norm1.eps == first_layer.norm2.eps):
            why_not_sparsity_fast_path = f"{str_first_layer}.norm1.eps was not equal to {str_first_layer}.norm2.eps"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif not self.enable_nested_tensor:
            why_not_sparsity_fast_path = "enable_nested_tensor was not True"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif first_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
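        # Layer-level checks passed: gather the tensors the fused path would touch and
        # run the remaining tensor-level checks (torch_function overrides, device, autograd).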
        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )

            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not (src.is_cuda or 'cpu' in str(src.device)):
                why_not_sparsity_fast_path = "src is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")
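        # All checks passed: pack the padded batch into a NestedTensor so the encoder
        # layers skip padding positions; the mask becomes implicit and is dropped.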
        if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
            convert_to_nested = True
            output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
            src_key_padding_mask_for_layers = None
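        # Run the input through each encoder layer in turn.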
        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
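        # Convert the NestedTensor output back to a regular padded tensor (padding filled with 0).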
        if convert_to_nested:
            output = output.to_padded_tensor(0.)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):
    """
    Modified from torch.nn.TransformerEncoderLayer.