mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 13:34:20 +00:00
use torch.div
This commit is contained in:
parent
c8adbcce64
commit
f57f1b8a44
@ -19,7 +19,6 @@
|
||||
# 2) https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -1558,12 +1557,6 @@ class EmformerEncoder(nn.Module):
|
||||
self.cnn_module_kernel - 1,
|
||||
), conv_caches[i].shape
|
||||
|
||||
# assert x.size(0) == self.chunk_length + self.right_context_length, (
|
||||
# "Per configured chunk_length and right_context_length, "
|
||||
# f"expected size of {self.chunk_length + self.right_context_length} "
|
||||
# f"for dimension 1 of x, but got {x.size(0)}."
|
||||
# )
|
||||
|
||||
right_context = x[-self.right_context_length :]
|
||||
utterance = x[: -self.right_context_length]
|
||||
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
|
||||
@ -1576,7 +1569,9 @@ class EmformerEncoder(nn.Module):
|
||||
# calcualte padding mask to mask out initial zero caches
|
||||
chunk_mask = make_pad_mask(output_lengths).to(x.device)
|
||||
memory_mask = (
|
||||
(num_processed_frames // self.chunk_length).view(x.size(1), 1)
|
||||
torch.div(
|
||||
num_processed_frames, self.chunk_length, rounding_mode="floor"
|
||||
).view(x.size(1), 1)
|
||||
<= torch.arange(self.memory_size, device=x.device).expand(
|
||||
x.size(1), self.memory_size
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user