commit d61e27625f
parent 8d43329787

    from local
@@ -510,67 +510,11 @@ class Tempformer(EncoderInterface):
     def forward(
         self, x
     ):
-        """
-        Args:
-          x:
-            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
-          x_lens:
-            A tensor of shape (batch_size,) containing the number of frames in
-            `x` before padding.
-        Returns:
-          Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, d_model)
-            - lengths, a tensor of shape (batch_size,) containing the number
-              of frames in `embeddings` before padding.
-        """
-        x = self.encoder_embed(x)
-        x, pos_emb = self.encoder_pos(x)
-        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-
-        lengths = (((x_lens - 1) >> 1) - 1) >> 1
-        assert x.size(0) == lengths.max().item()
-        src_key_padding_mask = make_pad_mask(lengths)
-
-        if self.dynamic_chunk_training:
-            assert (
-                self.causal
-            ), "Causal convolution is required for streaming conformer."
-            max_len = x.size(0)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
-            else:
-                chunk_size = chunk_size % self.short_chunk_size + 1
-
-            mask = ~subsequent_chunk_mask(
-                size=x.size(0),
-                chunk_size=chunk_size,
-                num_left_chunks=self.num_left_chunks,
-                device=x.device,
-            )
-            x = self.encoder(
-                x,
-                pos_emb,
-                mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-            )  # (T, N, C)
-        else:
-            x, layer_outputs = self.encoder(
-                x,
-                pos_emb,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-            )  # (T, N, C)
-
-        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-        layer_outputs = [x.permute(1, 0, 2) for x in layer_outputs]
-
-        if get_layer_output:
-            return x, lengths, layer_outputs
-        else:
-            return x, lengths
+        layer_outputs = []
+        for enum, encoder in enumerate(self.encoder_layers):
+            layer_outputs.append(encoder(x[enum]))
+        return layer_outputs
 
 
 class ConformerEncoderLayer(nn.Module):
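
Note on the removed length computation: `lengths = (((x_lens - 1) >> 1) - 1) >> 1` mirrors two successive kernel-3, stride-2 convolutions of a Conv2dSubsampling-style front end, each of which maps a length L to floor((L - 1) / 2). A minimal sketch checking that correspondence; the input lengths here are made up for illustration:

import torch

# Two stride-2 convolutions, each mapping L -> floor((L - 1) / 2).
x_lens = torch.tensor([100, 83, 17])

after_conv1 = (x_lens - 1) >> 1       # length after the first conv
after_conv2 = (after_conv1 - 1) >> 1  # length after the second conv

# Same result as the single expression used in the removed branch.
assert torch.equal(after_conv2, (((x_lens - 1) >> 1) - 1) >> 1)
print(after_conv2)  # tensor([24, 20,  3])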
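
Note on the removed `dynamic_chunk_training` branch: it draws a random chunk size per batch (falling back to full context above `short_chunk_threshold`, otherwise folding into the range 1..`short_chunk_size`), then negates `subsequent_chunk_mask`, since PyTorch-style attention masks use True for positions that must NOT be attended to. The helper itself is not part of this diff; below is a sketch assuming the WeNet-style semantics, where True in the returned mask marks positions frame i may attend to:

import torch

def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Boolean (size, size) mask; ret[i, j] is True if frame i may
    attend to frame j: its own chunk plus at most `num_left_chunks`
    chunks to the left (-1 means unlimited left context)."""
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

# Negated as in the removed branch: True now blocks attention.
mask = ~subsequent_chunk_mask(size=6, chunk_size=2, num_left_chunks=1)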