from local

dohe0342 2023-01-09 19:43:22 +09:00
parent 8d43329787
commit d61e27625f
2 changed files with 4 additions and 60 deletions

@@ -510,67 +510,11 @@ class Tempformer(EncoderInterface):
    def forward(
        self, x
    ):
-        """
-        Args:
-          x:
-            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
-          x_lens:
-            A tensor of shape (batch_size,) containing the number of frames in
-            `x` before padding.
-        Returns:
-          Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, d_model)
-            - lengths, a tensor of shape (batch_size,) containing the number
-              of frames in `embeddings` before padding.
-        """
-        x = self.encoder_embed(x)
-        x, pos_emb = self.encoder_pos(x)
-        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-
-        lengths = (((x_lens - 1) >> 1) - 1) >> 1
-        assert x.size(0) == lengths.max().item()
-
-        src_key_padding_mask = make_pad_mask(lengths)
-
-        if self.dynamic_chunk_training:
-            assert (
-                self.causal
-            ), "Causal convolution is required for streaming conformer."
-            max_len = x.size(0)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
-            else:
-                chunk_size = chunk_size % self.short_chunk_size + 1
-            mask = ~subsequent_chunk_mask(
-                size=x.size(0),
-                chunk_size=chunk_size,
-                num_left_chunks=self.num_left_chunks,
-                device=x.device,
-            )
-            x = self.encoder(
-                x,
-                pos_emb,
-                mask=mask,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-            )  # (T, N, C)
-        else:
-            x, layer_outputs = self.encoder(
-                x,
-                pos_emb,
-                src_key_padding_mask=src_key_padding_mask,
-                warmup=warmup,
-            )  # (T, N, C)
-
-        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-        layer_outputs = [x.permute(1, 0, 2) for x in layer_outputs]
-
-        if get_layer_output:
-            return x, lengths, layer_outputs
-        else:
-            return x, lengths
+        layer_outputs = []
+        for enum, encoder in enumerate(self.encoder_layers):
+            layer_outputs.append(encoder(x[enum]))
+        return layer_outputs

class ConformerEncoderLayer(nn.Module):
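
The removed body derived the post-subsampling lengths with two successive `(n - 1) >> 1` halvings, the usual length rule for the 4x convolutional subsampling done by `encoder_embed`. A minimal sketch of that arithmetic (the example values are illustrative only):

```python
import torch

# Two "(n - 1) >> 1" halvings: frame counts after a 4x conv subsampling.
x_lens = torch.tensor([100, 83, 64])  # frames before padding, per utterance
lengths = (((x_lens - 1) >> 1) - 1) >> 1
print(lengths)  # tensor([24, 20, 15])
```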
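
The deleted `dynamic_chunk_training` branch sampled a random chunk size per batch and built the attention mask by inverting `subsequent_chunk_mask`. The helper's body is outside this diff, so the sketch below re-implements it in the usual icefall/WeNet style as an assumption; the chunk-size sampling itself is copied from the deleted lines:

```python
import torch

def subsequent_chunk_mask(size, chunk_size, num_left_chunks=-1, device="cpu"):
    # Assumed icefall/WeNet-style helper: position (i, j) is True when
    # frame i may attend to frame j under chunked causal attention.
    ret = torch.zeros(size, size, dtype=torch.bool, device=device)
    for i in range(size):
        start = 0 if num_left_chunks < 0 else max(
            (i // chunk_size - num_left_chunks) * chunk_size, 0
        )
        end = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:end] = True
    return ret

# Chunk-size sampling as in the deleted branch: fall back to full context
# above the threshold, otherwise wrap into [1, short_chunk_size].
max_len, short_chunk_threshold, short_chunk_size = 16, 0.75, 8
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * short_chunk_threshold):
    chunk_size = max_len
else:
    chunk_size = chunk_size % short_chunk_size + 1
mask = ~subsequent_chunk_mask(size=max_len, chunk_size=chunk_size, num_left_chunks=2)
```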
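
After the change, `forward` simply feeds one input slice per encoder layer and returns every layer's output as a list. A hedged usage sketch: `encoder_layers` is stubbed with `nn.Linear` modules because its real definition sits outside this hunk, and the (num_layers, N, T, C) input layout is an assumption inferred from the `x[enum]` indexing:

```python
import torch
import torch.nn as nn

class TempformerSketch(nn.Module):
    # Stand-in for the real Tempformer: only the rewritten forward is real.
    def __init__(self, num_layers=4, d_model=256):
        super().__init__()
        # Assumption: encoder_layers is an nn.ModuleList; Linear is a stub.
        self.encoder_layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(num_layers)]
        )

    def forward(self, x):
        layer_outputs = []
        for enum, encoder in enumerate(self.encoder_layers):
            layer_outputs.append(encoder(x[enum]))  # one input slice per layer
        return layer_outputs

model = TempformerSketch()
x = torch.randn(4, 2, 10, 256)  # assumed (num_layers, batch, time, d_model)
outs = model(x)                 # list of 4 tensors, each (2, 10, 256)
```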