mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-12 18:44:20 +00:00
Fix style
This commit is contained in:
parent
b23db42486
commit
fb3f3d2526
@ -15,7 +15,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
@ -172,22 +171,23 @@ class Conformer(EncoderInterface):
|
|||||||
chunk_size = chunk_size % self.short_chunk_size + 1
|
chunk_size = chunk_size % self.short_chunk_size + 1
|
||||||
|
|
||||||
mask = ~subsequent_chunk_mask(
|
mask = ~subsequent_chunk_mask(
|
||||||
size=x.size(0), chunk_size=chunk_size,
|
size=x.size(0),
|
||||||
num_left_chunks=self.num_left_chunks, device=x.device
|
chunk_size=chunk_size,
|
||||||
|
num_left_chunks=self.num_left_chunks,
|
||||||
|
device=x.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
x, _ = self.encoder(
|
x, _ = self.encoder(
|
||||||
x, pos_emb,
|
x,
|
||||||
|
pos_emb,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
warmup=warmup,
|
warmup=warmup,
|
||||||
) # (T, N, C)
|
) # (T, N, C)
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
return x, lengths
|
return x, lengths
|
||||||
|
|
||||||
|
|
||||||
def streaming_forward(
|
def streaming_forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -250,9 +250,16 @@ class Conformer(EncoderInterface):
|
|||||||
), "Require cache when sending data in streaming mode"
|
), "Require cache when sending data in streaming mode"
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
len(states) == 2 and
|
len(states) == 2
|
||||||
states[0].shape == (self.encoder_layers, left_context, x.size(0), self.d_model) and
|
and states[0].shape
|
||||||
states[1].shape == (self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)
|
== (self.encoder_layers, left_context, x.size(0), self.d_model)
|
||||||
|
and states[1].shape
|
||||||
|
== (
|
||||||
|
self.encoder_layers,
|
||||||
|
self.cnn_module_kernel - 1,
|
||||||
|
x.size(0),
|
||||||
|
self.d_model,
|
||||||
|
)
|
||||||
), f"""The length of states MUST be equal to 2, and the shape of
|
), f"""The length of states MUST be equal to 2, and the shape of
|
||||||
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
|
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
|
||||||
given {states[0].shape}. the shape of second element should be
|
given {states[0].shape}. the shape of second element should be
|
||||||
@ -293,7 +300,7 @@ class Conformer(EncoderInterface):
|
|||||||
size=x.size(0),
|
size=x.size(0),
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
num_left_chunks=num_left_chunks,
|
num_left_chunks=num_left_chunks,
|
||||||
device=x.device
|
device=x.device,
|
||||||
)
|
)
|
||||||
x = self.encoder(
|
x = self.encoder(
|
||||||
x,
|
x,
|
||||||
@ -366,9 +373,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.conv_module = ConvolutionModule(
|
self.conv_module = ConvolutionModule(
|
||||||
d_model,
|
d_model, cnn_module_kernel, causal=causal
|
||||||
cnn_module_kernel,
|
|
||||||
causal=causal
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm_final = BasicNorm(d_model)
|
self.norm_final = BasicNorm(d_model)
|
||||||
@ -546,7 +551,11 @@ class ConformerEncoder(nn.Module):
|
|||||||
assert left_context >= 0
|
assert left_context >= 0
|
||||||
|
|
||||||
for layer_index, mod in enumerate(self.layers):
|
for layer_index, mod in enumerate(self.layers):
|
||||||
cache = None if states is None else [states[0][layer_index], states[1][layer_index]]
|
cache = (
|
||||||
|
None
|
||||||
|
if states is None
|
||||||
|
else [states[0][layer_index], states[1][layer_index]]
|
||||||
|
)
|
||||||
output = mod(
|
output = mod(
|
||||||
output,
|
output,
|
||||||
pos_emb,
|
pos_emb,
|
||||||
@ -623,10 +632,10 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
context: int = 0
|
context: int = 0,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Add positional encoding.
|
"""Add positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1079,16 +1088,23 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
# the whole column of `attn_output_weights` will be `-inf`
|
# the whole column of `attn_output_weights` will be `-inf`
|
||||||
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
|
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
|
||||||
# positions to avoid invalid loss value below.
|
# positions to avoid invalid loss value below.
|
||||||
if attn_mask is not None and attn_mask.dtype == torch.bool and \
|
if (
|
||||||
key_padding_mask is not None:
|
attn_mask is not None
|
||||||
combined_mask = attn_mask.unsqueeze(
|
and attn_mask.dtype == torch.bool
|
||||||
0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
|
and key_padding_mask is not None
|
||||||
|
):
|
||||||
|
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
|
||||||
|
1
|
||||||
|
).unsqueeze(2)
|
||||||
attn_output_weights = attn_output_weights.view(
|
attn_output_weights = attn_output_weights.view(
|
||||||
bsz, num_heads, tgt_len, src_len)
|
bsz, num_heads, tgt_len, src_len
|
||||||
|
)
|
||||||
attn_output_weights = attn_output_weights.masked_fill(
|
attn_output_weights = attn_output_weights.masked_fill(
|
||||||
combined_mask, 0.0)
|
combined_mask, 0.0
|
||||||
|
)
|
||||||
attn_output_weights = attn_output_weights.view(
|
attn_output_weights = attn_output_weights.view(
|
||||||
bsz * num_heads, tgt_len, src_len)
|
bsz * num_heads, tgt_len, src_len
|
||||||
|
)
|
||||||
|
|
||||||
attn_output_weights = nn.functional.dropout(
|
attn_output_weights = nn.functional.dropout(
|
||||||
attn_output_weights, p=dropout_p, training=training
|
attn_output_weights, p=dropout_p, training=training
|
||||||
@ -1131,7 +1147,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
channels: int,
|
channels: int,
|
||||||
kernel_size: int,
|
kernel_size: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
causal: bool = False
|
causal: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Construct an ConvolutionModule object."""
|
"""Construct an ConvolutionModule object."""
|
||||||
super(ConvolutionModule, self).__init__()
|
super(ConvolutionModule, self).__init__()
|
||||||
@ -1197,10 +1213,10 @@ class ConvolutionModule(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
cache: Optional[Tensor] = None
|
cache: Optional[Tensor] = None,
|
||||||
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
||||||
"""Compute convolution module.
|
"""Compute convolution module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1231,10 +1247,12 @@ class ConvolutionModule(nn.Module):
|
|||||||
# manualy padding self.lorder zeros to the left
|
# manualy padding self.lorder zeros to the left
|
||||||
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||||
else:
|
else:
|
||||||
assert not self.training, "Cache should be None in training time"
|
assert (
|
||||||
|
not self.training
|
||||||
|
), "Cache should be None in training time"
|
||||||
assert cache.size(0) == self.lorder
|
assert cache.size(0) == self.lorder
|
||||||
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
|
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
|
||||||
cache = x.permute(2, 0, 1)[-self.lorder:,...]
|
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
|
||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
|
|
||||||
x = self.deriv_balancer2(x)
|
x = self.deriv_balancer2(x)
|
||||||
@ -1242,7 +1260,9 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||||
|
|
||||||
return x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
|
return (
|
||||||
|
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Conv2dSubsampling(nn.Module):
|
class Conv2dSubsampling(nn.Module):
|
||||||
|
@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -177,9 +176,9 @@ class Transducer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
offset = (boundary[:, 3] - 1) / 2
|
offset = (boundary[:, 3] - 1) / 2
|
||||||
total_syms = torch.sum(boundary[:, 2])
|
total_syms = torch.sum(boundary[:, 2])
|
||||||
offset = torch.arange(
|
offset = torch.arange(T0, device=px_grad.device).reshape(
|
||||||
T0, device=px_grad.device
|
1, 1, T0
|
||||||
).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
|
) - offset.reshape(B, 1, 1)
|
||||||
sym_delay = px_grad * offset
|
sym_delay = px_grad * offset
|
||||||
sym_delay = torch.sum(sym_delay) / total_syms
|
sym_delay = torch.sum(sym_delay) / total_syms
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user