Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
replace torch.div() with >>
parent 12c176c443
commit 5cfdbd3699
@@ -1252,6 +1252,10 @@ class EmformerEncoder(nn.Module):
     ):
         super().__init__()

+        assert int(math.log(chunk_length, 2)) == math.log(
+            chunk_length, 2
+        ), "chunk_length should be a power of 2."
+
         self.use_memory = memory_size > 0
         self.init_memory_op = nn.AvgPool1d(
             kernel_size=chunk_length,
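The added assert is what licenses the shift trick in the later hunks: int(math.log(chunk_length, 2)) == math.log(chunk_length, 2) holds only when chunk_length is an exact power of two. A standalone sketch of how the guard behaves (the sample values are illustrative, not from the commit):

import math

# Exact powers of two pass the guard; other chunk lengths fail it.
for chunk_length in [2, 4, 8, 16, 3, 6, 12]:
    is_pow2 = int(math.log(chunk_length, 2)) == math.log(chunk_length, 2)
    print(chunk_length, is_pow2)
# 2/4/8/16 -> True; 3/6/12 -> False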
@@ -1580,10 +1584,8 @@ class EmformerEncoder(nn.Module):
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
             (
-                torch.div(
-                    num_processed_frames,
-                    self.chunk_length,
-                    rounding_mode="floor",
+                (
+                    num_processed_frames >> int(math.log(self.chunk_length, 2))
                 ).view(x.size(1), 1)
                 <= torch.arange(self.memory_size, device=x.device).expand(
                     x.size(1), self.memory_size
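With chunk_length guaranteed to be a power of two, flooring division by self.chunk_length is equivalent to right-shifting by log2(chunk_length) bits for the nonnegative frame counts used here, which is what lets the torch.div(..., rounding_mode="floor") call go away. A quick sanity check of that equivalence (a minimal sketch assuming nonnegative integer tensors; not code from the repo):

import math
import torch

chunk_length = 8  # must be a power of 2, per the new assert
shift = int(math.log(chunk_length, 2))

num_processed_frames = torch.arange(0, 64, 7)  # arbitrary nonnegative frame counts
by_div = torch.div(num_processed_frames, chunk_length, rounding_mode="floor")
by_shift = num_processed_frames >> shift
assert torch.equal(by_div, by_shift)  # same result for every entry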
@@ -1188,6 +1188,10 @@ class EmformerEncoder(nn.Module):
     ):
         super().__init__()

+        assert int(math.log(chunk_length, 2)) == math.log(
+            chunk_length, 2
+        ), "chunk_length should be a power of 2."
+
         self.use_memory = memory_size > 0

         self.emformer_layers = nn.ModuleList(
@@ -1488,10 +1492,8 @@ class EmformerEncoder(nn.Module):
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
             (
-                torch.div(
-                    num_processed_frames,
-                    self.chunk_length,
-                    rounding_mode="floor",
+                (
+                    num_processed_frames >> int(math.log(self.chunk_length, 2))
                 ).view(x.size(1), 1)
                 <= torch.arange(self.memory_size, device=x.device).expand(
                     x.size(1), self.memory_size