mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Fix style
This commit is contained in:
parent
fb3f3d2526
commit
aecaecfb17
@ -169,8 +169,10 @@ class Conformer(Transformer):
|
||||
chunk_size = chunk_size % self.short_chunk_size + 1
|
||||
|
||||
mask = ~subsequent_chunk_mask(
|
||||
size=x.size(0), chunk_size=chunk_size,
|
||||
num_left_chunks=self.num_left_chunks, device=x.device
|
||||
size=x.size(0),
|
||||
chunk_size=chunk_size,
|
||||
num_left_chunks=self.num_left_chunks,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
x, _ = self.encoder(
|
||||
@ -185,7 +187,6 @@ class Conformer(Transformer):
|
||||
|
||||
return logits, lengths
|
||||
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -243,9 +244,16 @@ class Conformer(Transformer):
|
||||
), "Require cache when sending data in streaming mode"
|
||||
|
||||
assert (
|
||||
len(states) == 2 and
|
||||
states[0].shape == (self.encoder_layers, left_context, x.size(0), self.d_model) and
|
||||
states[1].shape == (self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)
|
||||
len(states) == 2
|
||||
and states[0].shape
|
||||
== (self.encoder_layers, left_context, x.size(0), self.d_model)
|
||||
and states[1].shape
|
||||
== (
|
||||
self.encoder_layers,
|
||||
self.cnn_module_kernel - 1,
|
||||
x.size(0),
|
||||
self.d_model,
|
||||
)
|
||||
), f"""The length of states MUST be equal to 2, and the shape of
|
||||
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
|
||||
given {states[0].shape}. the shape of second element should be
|
||||
@ -285,7 +293,7 @@ class Conformer(Transformer):
|
||||
size=x.size(0),
|
||||
chunk_size=chunk_size,
|
||||
num_left_chunks=num_left_chunks,
|
||||
device=x.device
|
||||
device=x.device,
|
||||
)
|
||||
x = self.encoder(
|
||||
x,
|
||||
@ -544,7 +552,11 @@ class ConformerEncoder(nn.Module):
|
||||
assert left_context >= 0
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
cache = None if states is None else [states[0][layer_index], states[1][layer_index]]
|
||||
cache = (
|
||||
None
|
||||
if states is None
|
||||
else [states[0][layer_index], states[1][layer_index]]
|
||||
)
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
@ -621,10 +633,8 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: int = 0
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
self, x: torch.Tensor, context: int = 0
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
@ -1073,16 +1083,23 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
# the whole column of `attn_output_weights` will be `-inf`
|
||||
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
|
||||
# positions to avoid invalid loss value below.
|
||||
if attn_mask is not None and attn_mask.dtype == torch.bool and \
|
||||
key_padding_mask is not None:
|
||||
combined_mask = attn_mask.unsqueeze(
|
||||
0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
|
||||
if (
|
||||
attn_mask is not None
|
||||
and attn_mask.dtype == torch.bool
|
||||
and key_padding_mask is not None
|
||||
):
|
||||
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
|
||||
1
|
||||
).unsqueeze(2)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len)
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
combined_mask, 0.0)
|
||||
combined_mask, 0.0
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len)
|
||||
bsz * num_heads, tgt_len, src_len
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.dropout(
|
||||
attn_output_weights, p=dropout_p, training=training
|
||||
@ -1125,7 +1142,7 @@ class ConvolutionModule(nn.Module):
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
bias: bool = True,
|
||||
causal: bool = False
|
||||
causal: bool = False,
|
||||
) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
@ -1168,10 +1185,8 @@ class ConvolutionModule(nn.Module):
|
||||
self.activation = Swish()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
cache: Optional[Tensor] = None
|
||||
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
||||
self, x: Tensor, cache: Optional[Tensor] = None
|
||||
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
@ -1195,10 +1210,12 @@ class ConvolutionModule(nn.Module):
|
||||
# manualy padding self.lorder zeros to the left
|
||||
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||
else:
|
||||
assert not self.training, "Cache should be None in training time"
|
||||
assert (
|
||||
not self.training
|
||||
), "Cache should be None in training time"
|
||||
assert cache.size(0) == self.lorder
|
||||
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
|
||||
cache = x.permute(2, 0, 1)[-self.lorder:,...]
|
||||
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
|
||||
|
||||
x = self.depthwise_conv(x)
|
||||
# x is (batch, channels, time)
|
||||
@ -1210,7 +1227,9 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
return x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
|
||||
return (
|
||||
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
|
||||
)
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user