Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
fix bug about memory mask when memory_size==0

commit 10662c5c38, parent 9c37c16326
.flake8 (3 lines changed)
@@ -9,8 +9,7 @@ per-file-ignores =
     egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
     egs/*/ASR/*/optim.py: E501,
     egs/*/ASR/*/scaling.py: E501,
-    egs/librispeech/ASR/conv_emformer_transducer_stateless/*.py: E501, E203
-    egs/librispeech/ASR/conv_emformer_transducer_stateless2/*.py: E501, E203
+    egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
 
     # invalid escape sequence (cause by tex formular), W605
     icefall/utils.py: E501, W605
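The new entry uses a glob so that a single per-file-ignores line covers both conv_emformer recipe directories. A quick illustration of the glob itself (flake8 applies its own pattern matching; the paths below are just examples):

from fnmatch import fnmatch

# Illustration only: the consolidated glob matches .py files under both
# conv_emformer_transducer_stateless and conv_emformer_transducer_stateless2.
pattern = "egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py"
paths = [
    "egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py",
    "egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py",
]
print(all(fnmatch(p, pattern) for p in paths))  # True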
@@ -1579,13 +1579,19 @@ class EmformerEncoder(nn.Module):
         # calcualte padding mask to mask out initial zero caches
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
-            torch.div(
-                num_processed_frames, self.chunk_length, rounding_mode="floor"
-            ).view(x.size(1), 1)
-            <= torch.arange(self.memory_size, device=x.device).expand(
-                x.size(1), self.memory_size
-            )
-        ).flip(1)
+            (
+                torch.div(
+                    num_processed_frames,
+                    self.chunk_length,
+                    rounding_mode="floor",
+                ).view(x.size(1), 1)
+                <= torch.arange(self.memory_size, device=x.device).expand(
+                    x.size(1), self.memory_size
+                )
+            ).flip(1)
+            if self.use_memory
+            else torch.empty(0).to(dtype=torch.bool, device=x.device)
+        )
         left_context_mask = (
             num_processed_frames.view(x.size(1), 1)
             <= torch.arange(self.left_context_length, device=x.device).expand(
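The guard makes memory_mask an empty boolean tensor when the encoder is built without a memory bank, instead of deriving it from self.memory_size. A minimal sketch of the guarded expression with made-up sizes (these values are not taken from the recipe or its tests):

import torch

# Illustration only: standalone version of the guarded memory_mask above.
batch_size = 1
chunk_length = 4
memory_size = 4            # set to 0 to exercise the fixed configuration
use_memory = memory_size > 0
num_processed_frames = torch.tensor([8])  # two chunks already processed

memory_mask = (
    (
        torch.div(
            num_processed_frames, chunk_length, rounding_mode="floor"
        ).view(batch_size, 1)
        <= torch.arange(memory_size).expand(batch_size, memory_size)
    ).flip(1)
    if use_memory
    else torch.empty(0).to(dtype=torch.bool)
)
print(memory_mask)
# tensor([[ True,  True, False, False]]): two of the four memory slots are
# masked, matching the comment above about masking out initial zero caches.
# With memory_size = 0, the else branch yields an empty boolean tensor
# rather than a (batch, 0)-shaped mask.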
@@ -1388,7 +1388,11 @@ class EmformerEncoder(nn.Module):
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
 
-        M = right_context.size(0) // self.right_context_length - 1
+        M = (
+            right_context.size(0) // self.right_context_length - 1
+            if self.use_memory
+            else 0
+        )
         padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)
 
         output = utterance
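With the guard, the memory-vector count M is forced to 0 when use_memory is False, so the padding mask is sized for right-context plus utterance frames only. A small sketch with made-up numbers (make_pad_mask below is a simplified stand-in for icefall.utils.make_pad_mask, where True marks padding positions):

import torch

# Illustration only: how the guarded M affects the padding-mask width.
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in: True at padded positions.
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

right_context_length = 2
num_right_context_frames = 8           # plays the role of right_context.size(0)
output_lengths = torch.tensor([16, 12])
use_memory = False                     # the configuration this commit fixes

M = (
    num_right_context_frames // right_context_length - 1
    if use_memory
    else 0
)
padding_mask = make_pad_mask(M + num_right_context_frames + output_lengths)
print(M, padding_mask.shape)  # 0 torch.Size([2, 24])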
@@ -1480,13 +1484,19 @@ class EmformerEncoder(nn.Module):
         # calcualte padding mask to mask out initial zero caches
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
-            torch.div(
-                num_processed_frames, self.chunk_length, rounding_mode="floor"
-            ).view(x.size(1), 1)
-            <= torch.arange(self.memory_size, device=x.device).expand(
-                x.size(1), self.memory_size
-            )
-        ).flip(1)
+            (
+                torch.div(
+                    num_processed_frames,
+                    self.chunk_length,
+                    rounding_mode="floor",
+                ).view(x.size(1), 1)
+                <= torch.arange(self.memory_size, device=x.device).expand(
+                    x.size(1), self.memory_size
+                )
+            ).flip(1)
+            if self.use_memory
+            else torch.empty(0).to(dtype=torch.bool, device=x.device)
+        )
         left_context_mask = (
             num_processed_frames.view(x.size(1), 1)
             <= torch.arange(self.left_context_length, device=x.device).expand(
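This hunk applies the same use_memory guard to memory_mask at a second call site; the shifted target line (+1484 versus -1480) is consistent with the four lines added in the M hunk above. The else branch in both places evaluates to a one-dimensional, zero-element boolean tensor:

import torch

# Illustration only: the placeholder used when use_memory is False.
empty_mask = torch.empty(0).to(dtype=torch.bool)
print(empty_mask.shape, empty_mask.dtype, empty_mask.numel())
# torch.Size([0]) torch.bool 0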