mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
add adapter in streaming_forward
This commit is contained in:
parent
413220d6a4
commit
045b900670
@ -1004,6 +1004,9 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
)
|
||||
src = src + self_attn
|
||||
|
||||
if self.use_adapters and self.post_sa_adapter is not None:
|
||||
src = self.post_sa_adapter(src)
|
||||
|
||||
src_conv, cached_conv1 = self.conv_module1.streaming_forward(
|
||||
src,
|
||||
cache=cached_conv1,
|
||||
@ -1016,6 +1019,9 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
# bypass in the middle of the layer.
|
||||
src = self.bypass_mid(src_orig, src)
|
||||
|
||||
if self.use_adapters and self.mid_adapter is not None:
|
||||
src = self.mid_adapter(src)
|
||||
|
||||
self_attn, cached_val2 = self.self_attn2.streaming_forward(
|
||||
src,
|
||||
attn_weights=attn_weights,
|
||||
@ -1031,12 +1037,18 @@ class Zipformer2EncoderLayer(nn.Module):
|
||||
)
|
||||
src = src + src_conv
|
||||
|
||||
if self.use_adapters and self.post_conv_adapter is not None:
|
||||
src = self.post_conv_adapter(src)
|
||||
|
||||
src = src + self.feed_forward3(src)
|
||||
|
||||
src = self.norm(src)
|
||||
|
||||
src = self.bypass(src_orig, src)
|
||||
|
||||
if self.use_adapters and self.adapter is not None:
|
||||
src = self.adapter(src)
|
||||
|
||||
return (
|
||||
src,
|
||||
cached_key,
|
||||
|
Loading…
x
Reference in New Issue
Block a user