fix bug in ConvolutionModule

This commit is contained in:
yaozengwei 2022-05-20 22:28:41 +08:00
parent 19a8700301
commit 3806d2db19

View File

@ -419,7 +419,8 @@ class ConvolutionModule(nn.Module):
assert cache.shape == (B, D, self.cache_size), cache.shape
x = torch.cat([cache, x], dim=2) # (B, D, cache_size + U + R)
# update cache
new_cache = x[:, :, -R - self.cache_size : -R]
x_length = x.size(2)
new_cache = x[:, :, x_length - R - self.cache_size : x_length - R]
# 1-D depth-wise conv
x = self.depthwise_conv(x) # (B, D, U + R)