mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
minor updates
This commit is contained in:
parent
ceab22674f
commit
9cd95d88b0
@ -646,12 +646,7 @@ class EmformerAttention(nn.Module):
|
||||
- output of right context and utterance, with shape (R + U, B, D).
|
||||
- memory output, with shape (M, B, D), where M = S - 1 or M = 0.
|
||||
"""
|
||||
(
|
||||
output_right_context_utterance,
|
||||
output_memory,
|
||||
_,
|
||||
_,
|
||||
) = self._forward_impl(
|
||||
(output_right_context_utterance, output_memory, _, _,) = self._forward_impl(
|
||||
utterance,
|
||||
right_context,
|
||||
summary,
|
||||
@ -1115,11 +1110,7 @@ class EmformerEncoderLayer(nn.Module):
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
# emformer attention module
|
||||
(
|
||||
src_att,
|
||||
output_memory,
|
||||
attn_cache,
|
||||
) = self._apply_attention_module_infer(
|
||||
(src_att, output_memory, attn_cache,) = self._apply_attention_module_infer(
|
||||
src, R, memory, attn_cache, padding_mask=padding_mask
|
||||
)
|
||||
src = src + self.dropout(src_att)
|
||||
|
@ -927,11 +927,7 @@ class EmformerEncoderLayer(nn.Module):
|
||||
]
|
||||
else:
|
||||
memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
|
||||
(
|
||||
output_right_context_utterance,
|
||||
next_key,
|
||||
next_val,
|
||||
) = self.attention.infer(
|
||||
(output_right_context_utterance, next_key, next_val,) = self.attention.infer(
|
||||
utterance=utterance,
|
||||
right_context=right_context,
|
||||
memory=pre_memory,
|
||||
|
@ -1502,11 +1502,7 @@ class EmformerEncoder(nn.Module):
|
||||
end = start + 4
|
||||
cache = states[start:end]
|
||||
|
||||
(
|
||||
output,
|
||||
right_context,
|
||||
output_cache,
|
||||
) = layer.infer(
|
||||
(output, right_context, output_cache,) = layer.infer(
|
||||
output,
|
||||
right_context,
|
||||
padding_mask=None,
|
||||
|
@ -374,11 +374,7 @@ def streaming_forward(
|
||||
Returns encoder outputs, output lengths, and updated states.
|
||||
"""
|
||||
cached_embed_left_pad = states[-2]
|
||||
(
|
||||
x,
|
||||
x_lens,
|
||||
new_cached_embed_left_pad,
|
||||
) = model.encoder_embed.streaming_forward(
|
||||
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
cached_left_pad=cached_embed_left_pad,
|
||||
|
@ -20,6 +20,7 @@ kaldialign==0.7.1
|
||||
sentencepiece==0.1.96
|
||||
tensorboard==2.8.0
|
||||
typeguard==2.13.3
|
||||
black==22.3.0
|
||||
multi_quantization
|
||||
|
||||
onnx
|
||||
|
@ -5,3 +5,4 @@ sentencepiece>=0.1.96
|
||||
tensorboard
|
||||
typeguard
|
||||
dill
|
||||
black==22.3.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user