minor updates

This commit is contained in:
jinzr 2023-09-24 17:23:24 +08:00
parent ceab22674f
commit 9cd95d88b0
6 changed files with 7 additions and 26 deletions

View File

@@ -646,12 +646,7 @@ class EmformerAttention(nn.Module):
           - output of right context and utterance, with shape (R + U, B, D).
           - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
         """
-        (
-            output_right_context_utterance,
-            output_memory,
-            _,
-            _,
-        ) = self._forward_impl(
+        (output_right_context_utterance, output_memory, _, _,) = self._forward_impl(
             utterance,
             right_context,
             summary,
@@ -1115,11 +1110,7 @@ class EmformerEncoderLayer(nn.Module):
             src = src + self.dropout(self.feed_forward_macaron(src))
         # emformer attention module
-        (
-            src_att,
-            output_memory,
-            attn_cache,
-        ) = self._apply_attention_module_infer(
+        (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer(
             src, R, memory, attn_cache, padding_mask=padding_mask
         )
         src = src + self.dropout(src_att)

View File

@@ -927,11 +927,7 @@ class EmformerEncoderLayer(nn.Module):
             ]
         else:
             memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
-        (
-            output_right_context_utterance,
-            next_key,
-            next_val,
-        ) = self.attention.infer(
+        (output_right_context_utterance, next_key, next_val,) = self.attention.infer(
             utterance=utterance,
             right_context=right_context,
             memory=pre_memory,

View File

@@ -1502,11 +1502,7 @@ class EmformerEncoder(nn.Module):
             end = start + 4
             cache = states[start:end]
-            (
-                output,
-                right_context,
-                output_cache,
-            ) = layer.infer(
+            (output, right_context, output_cache,) = layer.infer(
                 output,
                 right_context,
                 padding_mask=None,

View File

@@ -374,11 +374,7 @@ def streaming_forward(
       Returns encoder outputs, output lengths, and updated states.
     """
     cached_embed_left_pad = states[-2]
-    (
-        x,
-        x_lens,
-        new_cached_embed_left_pad,
-    ) = model.encoder_embed.streaming_forward(
+    (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
         x=features,
         x_lens=feature_lens,
         cached_left_pad=cached_embed_left_pad,

View File

@@ -20,6 +20,7 @@ kaldialign==0.7.1
 sentencepiece==0.1.96
 tensorboard==2.8.0
 typeguard==2.13.3
+black==22.3.0
 multi_quantization
 onnx

View File

@@ -5,3 +5,4 @@ sentencepiece>=0.1.96
 tensorboard
 typeguard
 dill
+black==22.3.0