mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-18 13:34:20 +00:00

use torch.div

This commit is contained in:
parent c8adbcce64
commit f57f1b8a44
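Background for the change, as a minimal sketch (not part of the commit; the tensor values below are invented): on the PyTorch versions this commit targets, applying `//` to tensors raised a `__floordiv__` deprecation warning, and `torch.div(..., rounding_mode="floor")` is the warning-free equivalent.

```python
import torch

num_processed_frames = torch.tensor([0, 32, 64])  # hypothetical per-stream counters
chunk_length = 32

# Old spelling: tensor floor division via the `//` operator.
# On some PyTorch 1.8+ versions this emitted a __floordiv__ deprecation warning.
old = num_processed_frames // chunk_length

# New spelling used by this commit: explicit floor rounding, no warning.
new = torch.div(num_processed_frames, chunk_length, rounding_mode="floor")

assert torch.equal(old, new)  # identical integer results
```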
@@ -19,7 +19,6 @@
 # 2) https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa

 import math
-import warnings
 from typing import List, Optional, Tuple

 import torch
@@ -1558,12 +1557,6 @@ class EmformerEncoder(nn.Module):
                 self.cnn_module_kernel - 1,
             ), conv_caches[i].shape

-        # assert x.size(0) == self.chunk_length + self.right_context_length, (
-        #     "Per configured chunk_length and right_context_length, "
-        #     f"expected size of {self.chunk_length + self.right_context_length} "
-        #     f"for dimension 1 of x, but got {x.size(0)}."
-        # )
-
         right_context = x[-self.right_context_length :]
         utterance = x[: -self.right_context_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
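For readers unfamiliar with the context lines above, here is a standalone sketch of the chunk split they perform; the sizes and lengths are invented for illustration. In the Emformer streaming path, each input chunk `x` carries `right_context_length` look-ahead frames at its end, and `output_lengths` subtracts them so later masking covers only real utterance frames.

```python
import torch

# Hypothetical sizes, just for illustration.
chunk_length = 32
right_context_length = 8
T, N, D = chunk_length + right_context_length, 2, 4  # (time, batch, feature)

x = torch.randn(T, N, D)
lengths = torch.tensor([40, 20])  # valid frames per stream, including look-ahead

right_context = x[-right_context_length:]  # trailing look-ahead frames
utterance = x[:-right_context_length]      # the chunk proper
output_lengths = torch.clamp(lengths - right_context_length, min=0)

print(utterance.shape, right_context.shape, output_lengths)
# torch.Size([32, 2, 4]) torch.Size([8, 2, 4]) tensor([32, 12])
```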
@@ -1576,7 +1569,9 @@ class EmformerEncoder(nn.Module):
         # calculate padding mask to mask out initial zero caches
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
-            (num_processed_frames // self.chunk_length).view(x.size(1), 1)
+            torch.div(
+                num_processed_frames, self.chunk_length, rounding_mode="floor"
+            ).view(x.size(1), 1)
             <= torch.arange(self.memory_size, device=x.device).expand(
                 x.size(1), self.memory_size
             )
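A self-contained sketch of the floor-division piece of this hunk, with invented shapes (not the commit's code): each stream has completed `num_processed_frames // chunk_length` chunks, so that many memory slots hold real summaries while the rest are still the initial zero caches. The surrounding code post-processes this mask further, so treat the orientation below as illustrative only.

```python
import torch

# Invented setup: 3 parallel streams, a memory bank with 4 slots per stream.
memory_size = 4
chunk_length = 32
num_processed_frames = torch.tensor([0, 32, 96])  # per-stream frame counters

# Chunks completed so far; torch.div with rounding_mode="floor" is the
# warning-free equivalent of `num_processed_frames // chunk_length`.
num_chunks = torch.div(num_processed_frames, chunk_length, rounding_mode="floor")

# True marks memory slots that are still the initial zero caches and
# should be masked out of attention.
memory_mask = num_chunks.view(-1, 1) <= torch.arange(memory_size).expand(3, memory_size)

print(memory_mask)
# tensor([[ True,  True,  True,  True],   # 0 chunks processed: all slots are zeros
#         [False,  True,  True,  True],   # 1 chunk: one real memory vector
#         [False, False, False,  True]])  # 3 chunks: three real memory vectors
```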