use torch.div

yaozengwei 2022-06-11 21:21:49 +08:00
parent c8adbcce64
commit f57f1b8a44

@@ -19,7 +19,6 @@
 # 2) https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa
 import math
-import warnings
 from typing import List, Optional, Tuple
 import torch
@@ -1558,12 +1557,6 @@ class EmformerEncoder(nn.Module):
                 self.cnn_module_kernel - 1,
             ), conv_caches[i].shape
-        # assert x.size(0) == self.chunk_length + self.right_context_length, (
-        #     "Per configured chunk_length and right_context_length, "
-        #     f"expected size of {self.chunk_length + self.right_context_length} "
-        #     f"for dimension 1 of x, but got {x.size(0)}."
-        # )
         right_context = x[-self.right_context_length :]
         utterance = x[: -self.right_context_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
@@ -1576,7 +1569,9 @@ class EmformerEncoder(nn.Module):
         # calcualte padding mask to mask out initial zero caches
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
-            (num_processed_frames // self.chunk_length).view(x.size(1), 1)
+            torch.div(
+                num_processed_frames, self.chunk_length, rounding_mode="floor"
+            ).view(x.size(1), 1)
             <= torch.arange(self.memory_size, device=x.device).expand(
                 x.size(1), self.memory_size
             )
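
For context on the last hunk: torch.div with rounding_mode="floor" computes the same integer floor division as the // operator, but avoids the __floordiv__ deprecation UserWarning that some PyTorch versions emit for integer tensors, which appears to be the motivation for this commit. A minimal sketch of the equivalence; the variable names mirror the hunk above, but the concrete values here are made up for illustration:

import torch

# Toy per-stream counters standing in for the real num_processed_frames tensor.
num_processed_frames = torch.tensor([0, 32, 64, 96])
chunk_length = 32

# Old expression: floor division via the // operator (may warn on integer tensors).
old = (num_processed_frames // chunk_length).view(4, 1)

# New expression from this commit: explicit floor rounding via torch.div.
new = torch.div(num_processed_frames, chunk_length, rounding_mode="floor").view(4, 1)

# Both produce the same chunk indices.
assert torch.equal(old, new)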