mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
replace torch.div() with <<
This commit is contained in:
parent
12c176c443
commit
5cfdbd3699
@ -1252,6 +1252,10 @@ class EmformerEncoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert int(math.log(chunk_length, 2)) == math.log(
|
||||
chunk_length, 2
|
||||
), "chunk_length should be a power of 2."
|
||||
|
||||
self.use_memory = memory_size > 0
|
||||
self.init_memory_op = nn.AvgPool1d(
|
||||
kernel_size=chunk_length,
|
||||
@ -1580,10 +1584,8 @@ class EmformerEncoder(nn.Module):
|
||||
chunk_mask = make_pad_mask(output_lengths).to(x.device)
|
||||
memory_mask = (
|
||||
(
|
||||
torch.div(
|
||||
num_processed_frames,
|
||||
self.chunk_length,
|
||||
rounding_mode="floor",
|
||||
(
|
||||
num_processed_frames << int(math.log(self.chunk_length, 2))
|
||||
).view(x.size(1), 1)
|
||||
<= torch.arange(self.memory_size, device=x.device).expand(
|
||||
x.size(1), self.memory_size
|
||||
|
@ -1188,6 +1188,10 @@ class EmformerEncoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert int(math.log(chunk_length, 2)) == math.log(
|
||||
chunk_length, 2
|
||||
), "chunk_length should be a power of 2."
|
||||
|
||||
self.use_memory = memory_size > 0
|
||||
|
||||
self.emformer_layers = nn.ModuleList(
|
||||
@ -1488,10 +1492,8 @@ class EmformerEncoder(nn.Module):
|
||||
chunk_mask = make_pad_mask(output_lengths).to(x.device)
|
||||
memory_mask = (
|
||||
(
|
||||
torch.div(
|
||||
num_processed_frames,
|
||||
self.chunk_length,
|
||||
rounding_mode="floor",
|
||||
(
|
||||
num_processed_frames << int(math.log(self.chunk_length, 2))
|
||||
).view(x.size(1), 1)
|
||||
<= torch.arange(self.memory_size, device=x.device).expand(
|
||||
x.size(1), self.memory_size
|
||||
|
Loading…
x
Reference in New Issue
Block a user