replace torch.div() with <<

This commit is contained in:
yaozengwei 2022-07-08 16:41:46 +08:00
parent 12c176c443
commit 5cfdbd3699
2 changed files with 12 additions and 8 deletions

View File

@ -1252,6 +1252,10 @@ class EmformerEncoder(nn.Module):
): ):
super().__init__() super().__init__()
assert int(math.log(chunk_length, 2)) == math.log(
chunk_length, 2
), "chunk_length should be a power of 2."
self.use_memory = memory_size > 0 self.use_memory = memory_size > 0
self.init_memory_op = nn.AvgPool1d( self.init_memory_op = nn.AvgPool1d(
kernel_size=chunk_length, kernel_size=chunk_length,
@ -1580,10 +1584,8 @@ class EmformerEncoder(nn.Module):
chunk_mask = make_pad_mask(output_lengths).to(x.device) chunk_mask = make_pad_mask(output_lengths).to(x.device)
memory_mask = ( memory_mask = (
( (
torch.div( (
num_processed_frames, num_processed_frames << int(math.log(self.chunk_length, 2))
self.chunk_length,
rounding_mode="floor",
).view(x.size(1), 1) ).view(x.size(1), 1)
<= torch.arange(self.memory_size, device=x.device).expand( <= torch.arange(self.memory_size, device=x.device).expand(
x.size(1), self.memory_size x.size(1), self.memory_size

View File

@ -1188,6 +1188,10 @@ class EmformerEncoder(nn.Module):
): ):
super().__init__() super().__init__()
assert int(math.log(chunk_length, 2)) == math.log(
chunk_length, 2
), "chunk_length should be a power of 2."
self.use_memory = memory_size > 0 self.use_memory = memory_size > 0
self.emformer_layers = nn.ModuleList( self.emformer_layers = nn.ModuleList(
@ -1488,10 +1492,8 @@ class EmformerEncoder(nn.Module):
chunk_mask = make_pad_mask(output_lengths).to(x.device) chunk_mask = make_pad_mask(output_lengths).to(x.device)
memory_mask = ( memory_mask = (
( (
torch.div( (
num_processed_frames, num_processed_frames << int(math.log(self.chunk_length, 2))
self.chunk_length,
rounding_mode="floor",
).view(x.size(1), 1) ).view(x.size(1), 1)
<= torch.arange(self.memory_size, device=x.device).expand( <= torch.arange(self.memory_size, device=x.device).expand(
x.size(1), self.memory_size x.size(1), self.memory_size