add adapter in streaming_forward

This commit is contained in:
marcoyang 2024-03-19 18:39:39 +08:00
parent 413220d6a4
commit 045b900670

View File

@ -1004,6 +1004,9 @@ class Zipformer2EncoderLayer(nn.Module):
)
src = src + self_attn
if self.use_adapters and self.post_sa_adapter is not None:
src = self.post_sa_adapter(src)
src_conv, cached_conv1 = self.conv_module1.streaming_forward(
src,
cache=cached_conv1,
@ -1016,6 +1019,9 @@ class Zipformer2EncoderLayer(nn.Module):
# bypass in the middle of the layer.
src = self.bypass_mid(src_orig, src)
if self.use_adapters and self.mid_adapter is not None:
src = self.mid_adapter(src)
self_attn, cached_val2 = self.self_attn2.streaming_forward(
src,
attn_weights=attn_weights,
@ -1031,12 +1037,18 @@ class Zipformer2EncoderLayer(nn.Module):
)
src = src + src_conv
if self.use_adapters and self.post_conv_adapter is not None:
src = self.post_conv_adapter(src)
src = src + self.feed_forward3(src)
src = self.norm(src)
src = self.bypass(src_orig, src)
if self.use_adapters and self.adapter is not None:
src = self.adapter(src)
return (
src,
cached_key,