diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py
index 06485f3b5..ffb3cd8c3 100644
--- a/egs/aishell/ASR/conformer_ctc/transformer.py
+++ b/egs/aishell/ASR/conformer_ctc/transformer.py
@@ -381,6 +381,130 @@ class Transformer(nn.Module):
 
         return nll
 
+
+class TransformerEncoder(nn.Module):
+    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
+    BERT (https://arxiv.org/abs/1810.04805) model with corresponding parameters.
+
+    Args:
+        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
+        num_layers: the number of sub-encoder-layers in the encoder (required).
+        norm: the layer normalization component (optional).
+        enable_nested_tensor: if True, the input is automatically converted to a
+            nested tensor (and converted back on output). This improves the overall
+            performance of TransformerEncoder when the padding rate is high.
+            Default: ``True`` (enabled).
+        mask_check: if True, check that src_key_padding_mask is left-aligned before
+            converting the input to a nested tensor. Default: ``True``.
+
+    Examples::
+        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
+        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = transformer_encoder(src)
+    """
+    __constants__ = ['norm']
+
+    def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True):
+        super(TransformerEncoder, self).__init__()
+        # Deep-copy the given layer num_layers times (same helper torch.nn.TransformerEncoder uses).
+        self.layers = nn.modules.transformer._get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+        self.enable_nested_tensor = enable_nested_tensor
+        self.mask_check = mask_check
+
+    def forward(self, src: torch.Tensor, mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required).
+            mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            see the docs in Transformer class.
+        """
+        if src_key_padding_mask is not None:
+            _skpm_dtype = src_key_padding_mask.dtype
+            if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
+                raise AssertionError(
+                    "only bool and floating types of key_padding_mask are supported")
+        output = src
+        convert_to_nested = False
+        first_layer = self.layers[0]
+        src_key_padding_mask_for_layers = src_key_padding_mask
+        why_not_sparsity_fast_path = ''
+        str_first_layer = "self.layers[0]"
+        # Decide whether the NestedTensor (BetterTransformer) fast path can be used;
+        # the first failing condition is recorded in why_not_sparsity_fast_path.
+        if not isinstance(first_layer, torch.nn.TransformerEncoderLayer):
+            why_not_sparsity_fast_path = f"{str_first_layer} was not TransformerEncoderLayer"
+        elif first_layer.norm_first:
+            why_not_sparsity_fast_path = f"{str_first_layer}.norm_first was True"
+        elif first_layer.training:
+            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
+        elif not first_layer.self_attn.batch_first:
+            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn.batch_first was not True"
+        elif not first_layer.self_attn._qkv_same_embed_dim:
+            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn._qkv_same_embed_dim was not True"
+        elif not first_layer.activation_relu_or_gelu:
+            why_not_sparsity_fast_path = f"{str_first_layer}.activation_relu_or_gelu was not True"
+        elif not (first_layer.norm1.eps == first_layer.norm2.eps):
+            why_not_sparsity_fast_path = f"{str_first_layer}.norm1.eps was not equal to {str_first_layer}.norm2.eps"
+        elif not src.dim() == 3:
+            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
+        elif not self.enable_nested_tensor:
+            why_not_sparsity_fast_path = "enable_nested_tensor was not True"
+        elif src_key_padding_mask is None:
+            why_not_sparsity_fast_path = "src_key_padding_mask was None"
+        elif (((not hasattr(self, "mask_check")) or self.mask_check)
+                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
+            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask were not left aligned"
+        elif output.is_nested:
+            why_not_sparsity_fast_path = "NestedTensor input is not supported"
+        elif mask is not None:
+            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
+        elif first_layer.self_attn.num_heads % 2 == 1:
+            why_not_sparsity_fast_path = "num_head is odd"
+        elif torch.is_autocast_enabled():
+            why_not_sparsity_fast_path = "autocast is enabled"
+
+        if not why_not_sparsity_fast_path:
+            tensor_args = (
+                src,
+                first_layer.self_attn.in_proj_weight,
+                first_layer.self_attn.in_proj_bias,
+                first_layer.self_attn.out_proj.weight,
+                first_layer.self_attn.out_proj.bias,
+                first_layer.norm1.weight,
+                first_layer.norm1.bias,
+                first_layer.norm2.weight,
+                first_layer.norm2.bias,
+                first_layer.linear1.weight,
+                first_layer.linear1.bias,
+                first_layer.linear2.weight,
+                first_layer.linear2.bias,
+            )
+
+            if torch.overrides.has_torch_function(tensor_args):
+                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
+            elif not (src.is_cuda or 'cpu' in str(src.device)):
+                why_not_sparsity_fast_path = "src is neither CUDA nor CPU"
+            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
+                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
+                                              "input/output projection weights or biases requires_grad")
+
+        if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
+            convert_to_nested = True
+            output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
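+            # Padding is now encoded in the NestedTensor itself, so the per-layer
+            # key padding mask can be dropped.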
+            src_key_padding_mask_for_layers = None
+
+        for mod in self.layers:
+            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
+
+        if convert_to_nested:
+            output = output.to_padded_tensor(0.)
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+
 class TransformerEncoderLayer(nn.Module):
     """
     Modified from torch.nn.TransformerEncoderLayer.
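Minimal usage sketch for the TransformerEncoder added above (illustrative only, not part of the patch; variable names are assumptions, and it presumes a PyTorch build that provides the private torch._nested_tensor_from_mask helpers this code relies on). The fast path guarded by why_not_sparsity_fast_path is taken only in eval mode with a batch_first torch.nn.TransformerEncoderLayer, a relu/gelu activation, an even number of heads, no src mask, gradients disabled, and a left-aligned boolean src_key_padding_mask; otherwise the same call falls back to the ordinary per-layer loop.

    import torch

    encoder_layer = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
    encoder = TransformerEncoder(encoder_layer, num_layers=6).eval()

    src = torch.rand(32, 10, 512)                  # (batch, time, feature) because batch_first=True
    padding_mask = torch.zeros(32, 10, dtype=torch.bool)
    padding_mask[1:, 7:] = True                    # padded tails; valid frames stay left-aligned

    with torch.no_grad():                          # grad must be disabled for the fast path
        out = encoder(src, src_key_padding_mask=padding_mask)
    # out is (32, 10, 512); padded positions are filled with 0 by to_padded_tensor.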