Bug fix regarding chunk-size reshaping

Daniel Povey 2023-05-16 17:29:34 +08:00
parent 5f5df4367d
commit cf93d1f129


@@ -682,13 +682,12 @@ class SubformerEncoder(nn.Module):
     def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor:
         """
-        Reshape embeddings 'src' to have a different chunk size (in frames)
+        Reshape embeddings 'src' to have a different chunk size (in frames) by
+        changing the batch size.
         """
         (seq_len, batch_size, num_channels) = src.shape
-        num_chunks = src.shape[0] // chunk_size
-        src = src.reshape(num_chunks, chunk_size, batch_size, num_channels)
-        return src.contiguous().reshape(chunk_size, num_chunks * batch_size,
-                                        num_channels)
+        src = src.transpose(0, 1).contiguous().reshape(-1, chunk_size, num_channels)
+        return src.transpose(0, 1).contiguous()

     def _attn_offset_to_chunk_size(self, attn_offset: Tensor, batch_size: int, chunk_size: int) -> Tensor:
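For reference, here is a minimal standalone sketch of why the old code was a bug. The function names and the free-function form are illustrative, not part of the repo: a plain reshape of the seq-major (seq_len, batch, channels) layout mixes the chunk index with the frame-within-chunk index, while the fixed version transposes to batch-major first so each chunk becomes a contiguous run of frames before being folded into the batch dimension.

```python
import torch
from torch import Tensor

def to_chunk_size_old(src: Tensor, chunk_size: int) -> Tensor:
    # Pre-commit (buggy) version: the final plain reshape of a seq-major
    # tensor scrambles the chunk index with the frame-within-chunk index.
    (seq_len, batch_size, num_channels) = src.shape
    num_chunks = src.shape[0] // chunk_size
    src = src.reshape(num_chunks, chunk_size, batch_size, num_channels)
    return src.contiguous().reshape(chunk_size, num_chunks * batch_size,
                                    num_channels)

def to_chunk_size_new(src: Tensor, chunk_size: int) -> Tensor:
    # Fixed version: go batch-major first so each chunk is a contiguous run
    # of frames, then fold the chunk index into the batch dimension.
    (seq_len, batch_size, num_channels) = src.shape
    src = src.transpose(0, 1).contiguous().reshape(-1, chunk_size, num_channels)
    return src.transpose(0, 1).contiguous()

# Toy input: seq_len=4, batch_size=1, one channel holding the frame index,
# split into two chunks of two frames each.
src = torch.arange(4.0).reshape(4, 1, 1)
print(to_chunk_size_old(src, 2)[:, :, 0])  # tensor([[0., 1.], [2., 3.]])
print(to_chunk_size_new(src, 2)[:, :, 0])  # tensor([[0., 2.], [1., 3.]])
# Row t should hold frame t of every chunk, i.e. [0, 2] then [1, 3].
# The old version instead puts both frames of chunk 0 at time step 0.
```

With the fixed version, each output time step t holds frame t of every chunk, and the batch dimension simply grows by the number of chunks; the old version placed frames of the same chunk at the wrong time positions.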