Fix style

pkufool 2022-05-29 07:35:52 +08:00
parent fb3f3d2526
commit aecaecfb17


@@ -169,8 +169,10 @@ class Conformer(Transformer):
                 chunk_size = chunk_size % self.short_chunk_size + 1
             mask = ~subsequent_chunk_mask(
-                size=x.size(0), chunk_size=chunk_size,
-                num_left_chunks=self.num_left_chunks, device=x.device
+                size=x.size(0),
+                chunk_size=chunk_size,
+                num_left_chunks=self.num_left_chunks,
+                device=x.device,
             )
             x, _ = self.encoder(
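Note: the call reformatted above builds the chunk-wise attention mask used for dynamic chunk training; `subsequent_chunk_mask` is assumed to return True where a frame may be attended, and the caller negates it with `~` so that True then marks blocked positions. The snippet below is only a rough, self-contained sketch of that kind of mask, not the repository's implementation, and the name `chunkwise_mask_sketch` is invented for illustration.

import torch


def chunkwise_mask_sketch(
    size: int, chunk_size: int, num_left_chunks: int = -1
) -> torch.Tensor:
    # True = frame j is visible from frame i (own chunk plus up to
    # num_left_chunks chunks to the left; -1 means unlimited left context).
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        end = min((i // chunk_size + 1) * chunk_size, size)
        mask[i, start:end] = True
    return mask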
@@ -185,7 +187,6 @@ class Conformer(Transformer):
         return logits, lengths

     def streaming_forward(
         self,
         x: torch.Tensor,
@ -243,9 +244,16 @@ class Conformer(Transformer):
), "Require cache when sending data in streaming mode" ), "Require cache when sending data in streaming mode"
assert ( assert (
len(states) == 2 and len(states) == 2
states[0].shape == (self.encoder_layers, left_context, x.size(0), self.d_model) and and states[0].shape
states[1].shape == (self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model) == (self.encoder_layers, left_context, x.size(0), self.d_model)
and states[1].shape
== (
self.encoder_layers,
self.cnn_module_kernel - 1,
x.size(0),
self.d_model,
)
), f"""The length of states MUST be equal to 2, and the shape of ), f"""The length of states MUST be equal to 2, and the shape of
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)}, first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
given {states[0].shape}. the shape of second element should be given {states[0].shape}. the shape of second element should be
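For clarity, the two cached tensors that the assertion above checks could be zero-initialized as in the sketch below. This is an illustration only; the helper name and its arguments simply mirror the model attributes that appear in the assertion (states[0] presumably caches the left-context attention frames, states[1] the causal-convolution inputs).

import torch


def init_streaming_states(
    encoder_layers: int,
    left_context: int,
    batch_size: int,
    d_model: int,
    cnn_module_kernel: int,
    device: torch.device,
):
    # shape checked for states[0]
    attn_cache = torch.zeros(
        encoder_layers, left_context, batch_size, d_model, device=device
    )
    # shape checked for states[1]
    conv_cache = torch.zeros(
        encoder_layers, cnn_module_kernel - 1, batch_size, d_model, device=device
    )
    return [attn_cache, conv_cache]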
@@ -285,7 +293,7 @@ class Conformer(Transformer):
                 size=x.size(0),
                 chunk_size=chunk_size,
                 num_left_chunks=num_left_chunks,
-                device=x.device
+                device=x.device,
             )
             x = self.encoder(
                 x,
@ -544,7 +552,11 @@ class ConformerEncoder(nn.Module):
assert left_context >= 0 assert left_context >= 0
for layer_index, mod in enumerate(self.layers): for layer_index, mod in enumerate(self.layers):
cache = None if states is None else [states[0][layer_index], states[1][layer_index]] cache = (
None
if states is None
else [states[0][layer_index], states[1][layer_index]]
)
output = mod( output = mod(
output, output,
pos_emb, pos_emb,
@ -621,9 +633,7 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = pe.to(device=x.device, dtype=x.dtype) self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward( def forward(
self, self, x: torch.Tensor, context: int = 0
x: torch.Tensor,
context: int = 0
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
"""Add positional encoding. """Add positional encoding.
@ -1073,16 +1083,23 @@ class RelPositionMultiheadAttention(nn.Module):
# the whole column of `attn_output_weights` will be `-inf` # the whole column of `attn_output_weights` will be `-inf`
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
# positions to avoid invalid loss value below. # positions to avoid invalid loss value below.
if attn_mask is not None and attn_mask.dtype == torch.bool and \ if (
key_padding_mask is not None: attn_mask is not None
combined_mask = attn_mask.unsqueeze( and attn_mask.dtype == torch.bool
0) | key_padding_mask.unsqueeze(1).unsqueeze(2) and key_padding_mask is not None
):
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
attn_output_weights = attn_output_weights.view( attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len) bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill( attn_output_weights = attn_output_weights.masked_fill(
combined_mask, 0.0) combined_mask, 0.0
)
attn_output_weights = attn_output_weights.view( attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len) bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.dropout( attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training attn_output_weights, p=dropout_p, training=training
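The comment block above explains why the fully masked positions are filled with 0.0; the shape arithmetic of the combined mask in this hunk can be checked with the standalone snippet below (toy sizes, not taken from the repository).

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5

attn_mask = torch.zeros(tgt_len, src_len, dtype=torch.bool)      # (T, S)
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)   # (N, S)

# (1, T, S) | (N, 1, 1, S) broadcasts to (N, 1, T, S)
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
assert combined_mask.shape == (bsz, 1, tgt_len, src_len)

weights = torch.randn(bsz * num_heads, tgt_len, src_len)
weights = weights.view(bsz, num_heads, tgt_len, src_len)
weights = weights.masked_fill(combined_mask, 0.0)  # broadcasts over num_heads
weights = weights.view(bsz * num_heads, tgt_len, src_len)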
@ -1125,7 +1142,7 @@ class ConvolutionModule(nn.Module):
channels: int, channels: int,
kernel_size: int, kernel_size: int,
bias: bool = True, bias: bool = True,
causal: bool = False causal: bool = False,
) -> None: ) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__() super(ConvolutionModule, self).__init__()
@ -1168,9 +1185,7 @@ class ConvolutionModule(nn.Module):
self.activation = Swish() self.activation = Swish()
def forward( def forward(
self, self, x: Tensor, cache: Optional[Tensor] = None
x: Tensor,
cache: Optional[Tensor] = None
) -> Union[Tensor, Tuple[Tensor, Tensor]]: ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Compute convolution module. """Compute convolution module.
@ -1195,10 +1210,12 @@ class ConvolutionModule(nn.Module):
# manualy padding self.lorder zeros to the left # manualy padding self.lorder zeros to the left
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
else: else:
assert not self.training, "Cache should be None in training time" assert (
not self.training
), "Cache should be None in training time"
assert cache.size(0) == self.lorder assert cache.size(0) == self.lorder
x = torch.cat([cache.permute(1, 2, 0), x], dim=2) x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
cache = x.permute(2, 0, 1)[-self.lorder:,...] cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
# x is (batch, channels, time) # x is (batch, channels, time)
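A standalone sketch of the caching pattern in this hunk, with toy tensors and `lorder = kernel_size - 1` as in the module; the depthwise convolution here is a stand-in, not the module's own layer.

import torch
import torch.nn as nn

batch, channels, chunk_frames, lorder = 3, 8, 10, 2      # toy sizes

x = torch.randn(batch, channels, chunk_frames)           # current chunk
cache = torch.zeros(lorder, batch, channels)             # tail of the previous chunk

x = torch.cat([cache.permute(1, 2, 0), x], dim=2)        # prepend cached frames in time
new_cache = x.permute(2, 0, 1)[-lorder:, ...]            # keep the last lorder frames

depthwise = nn.Conv1d(channels, channels, kernel_size=lorder + 1, groups=channels)
y = depthwise(x)                                         # no extra left padding needed
assert y.shape == (batch, channels, chunk_frames)        # output aligned with the chunk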
@ -1210,7 +1227,9 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time) x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache) return (
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
)
class Swish(torch.nn.Module): class Swish(torch.nn.Module):