mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
minor fix
This commit is contained in:
parent
e3e8b1990c
commit
ad68987423
@ -1253,8 +1253,9 @@ class EmformerEncoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
chunk_length - 1 & chunk_length == 0
|
||||
), "chunk_length should be a power of 2."
|
||||
chunk_length - 1
|
||||
) & chunk_length == 0, "chunk_length should be a power of 2."
|
||||
self.shift = int(math.log(chunk_length, 2))
|
||||
|
||||
self.use_memory = memory_size > 0
|
||||
self.init_memory_op = nn.AvgPool1d(
|
||||
@ -1584,9 +1585,7 @@ class EmformerEncoder(nn.Module):
|
||||
chunk_mask = make_pad_mask(output_lengths).to(x.device)
|
||||
memory_mask = (
|
||||
(
|
||||
(
|
||||
num_processed_frames >> int(math.log(self.chunk_length, 2))
|
||||
).view(x.size(1), 1)
|
||||
(num_processed_frames >> self.shift).view(x.size(1), 1)
|
||||
<= torch.arange(self.memory_size, device=x.device).expand(
|
||||
x.size(1), self.memory_size
|
||||
)
|
||||
|
@ -1189,8 +1189,9 @@ class EmformerEncoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
chunk_length - 1 & chunk_length == 0
|
||||
), "chunk_length should be a power of 2."
|
||||
chunk_length - 1
|
||||
) & chunk_length == 0, "chunk_length should be a power of 2."
|
||||
self.shift = int(math.log(chunk_length, 2))
|
||||
|
||||
self.use_memory = memory_size > 0
|
||||
|
||||
@ -1492,9 +1493,7 @@ class EmformerEncoder(nn.Module):
|
||||
chunk_mask = make_pad_mask(output_lengths).to(x.device)
|
||||
memory_mask = (
|
||||
(
|
||||
(
|
||||
num_processed_frames >> int(math.log(self.chunk_length, 2))
|
||||
).view(x.size(1), 1)
|
||||
(num_processed_frames >> self.shift).view(x.size(1), 1)
|
||||
<= torch.arange(self.memory_size, device=x.device).expand(
|
||||
x.size(1), self.memory_size
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user