mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
8d43329787
commit
d61e27625f
Binary file not shown.
@ -510,67 +510,11 @@ class Tempformer(EncoderInterface):
|
||||
def forward(
|
||||
self, x
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing 2 tensors:
|
||||
- embeddings: its shape is (batch_size, output_seq_len, d_model)
|
||||
- lengths, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `embeddings` before padding.
|
||||
"""
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||
assert x.size(0) == lengths.max().item()
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
|
||||
if self.dynamic_chunk_training:
|
||||
assert (
|
||||
self.causal
|
||||
), "Causal convolution is required for streaming conformer."
|
||||
max_len = x.size(0)
|
||||
chunk_size = torch.randint(1, max_len, (1,)).item()
|
||||
if chunk_size > (max_len * self.short_chunk_threshold):
|
||||
chunk_size = max_len
|
||||
else:
|
||||
chunk_size = chunk_size % self.short_chunk_size + 1
|
||||
|
||||
mask = ~subsequent_chunk_mask(
|
||||
size=x.size(0),
|
||||
chunk_size=chunk_size,
|
||||
num_left_chunks=self.num_left_chunks,
|
||||
device=x.device,
|
||||
)
|
||||
x = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
) # (T, N, C)
|
||||
else:
|
||||
x, layer_outputs = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
) # (T, N, C)
|
||||
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
layer_outputs = [x.permute(1, 0, 2) for x in layer_outputs]
|
||||
|
||||
if get_layer_output:
|
||||
return x, lengths, layer_outputs
|
||||
else:
|
||||
return x, lengths
|
||||
layer_outputs = []
|
||||
for enum, encoder in enumerate(self.encoder_layers):
|
||||
layer_outputs.append(encoder(x[enum]))
|
||||
|
||||
return layer_outputs
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user