minor fix

This commit is contained in:
yaozengwei 2022-07-08 17:28:45 +08:00
parent e3e8b1990c
commit ad68987423
2 changed files with 8 additions and 10 deletions

View File

@ -1253,8 +1253,9 @@ class EmformerEncoder(nn.Module):
super().__init__()
assert (
chunk_length - 1 & chunk_length == 0
), "chunk_length should be a power of 2."
chunk_length - 1
) & chunk_length == 0, "chunk_length should be a power of 2."
self.shift = int(math.log(chunk_length, 2))
self.use_memory = memory_size > 0
self.init_memory_op = nn.AvgPool1d(
@ -1584,9 +1585,7 @@ class EmformerEncoder(nn.Module):
chunk_mask = make_pad_mask(output_lengths).to(x.device)
memory_mask = (
(
(
num_processed_frames >> int(math.log(self.chunk_length, 2))
).view(x.size(1), 1)
(num_processed_frames >> self.shift).view(x.size(1), 1)
<= torch.arange(self.memory_size, device=x.device).expand(
x.size(1), self.memory_size
)

View File

@ -1189,8 +1189,9 @@ class EmformerEncoder(nn.Module):
super().__init__()
assert (
chunk_length - 1 & chunk_length == 0
), "chunk_length should be a power of 2."
chunk_length - 1
) & chunk_length == 0, "chunk_length should be a power of 2."
self.shift = int(math.log(chunk_length, 2))
self.use_memory = memory_size > 0
@ -1492,9 +1493,7 @@ class EmformerEncoder(nn.Module):
chunk_mask = make_pad_mask(output_lengths).to(x.device)
memory_mask = (
(
(
num_processed_frames >> int(math.log(self.chunk_length, 2))
).view(x.size(1), 1)
(num_processed_frames >> self.shift).view(x.size(1), 1)
<= torch.arange(self.memory_size, device=x.device).expand(
x.size(1), self.memory_size
)