mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add adapter in streaming_forward
This commit is contained in:
parent
413220d6a4
commit
045b900670
@ -1004,6 +1004,9 @@ class Zipformer2EncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
src = src + self_attn
|
src = src + self_attn
|
||||||
|
|
||||||
|
if self.use_adapters and self.post_sa_adapter is not None:
|
||||||
|
src = self.post_sa_adapter(src)
|
||||||
|
|
||||||
src_conv, cached_conv1 = self.conv_module1.streaming_forward(
|
src_conv, cached_conv1 = self.conv_module1.streaming_forward(
|
||||||
src,
|
src,
|
||||||
cache=cached_conv1,
|
cache=cached_conv1,
|
||||||
@ -1016,6 +1019,9 @@ class Zipformer2EncoderLayer(nn.Module):
|
|||||||
# bypass in the middle of the layer.
|
# bypass in the middle of the layer.
|
||||||
src = self.bypass_mid(src_orig, src)
|
src = self.bypass_mid(src_orig, src)
|
||||||
|
|
||||||
|
if self.use_adapters and self.mid_adapter is not None:
|
||||||
|
src = self.mid_adapter(src)
|
||||||
|
|
||||||
self_attn, cached_val2 = self.self_attn2.streaming_forward(
|
self_attn, cached_val2 = self.self_attn2.streaming_forward(
|
||||||
src,
|
src,
|
||||||
attn_weights=attn_weights,
|
attn_weights=attn_weights,
|
||||||
@ -1031,12 +1037,18 @@ class Zipformer2EncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
src = src + src_conv
|
src = src + src_conv
|
||||||
|
|
||||||
|
if self.use_adapters and self.post_conv_adapter is not None:
|
||||||
|
src = self.post_conv_adapter(src)
|
||||||
|
|
||||||
src = src + self.feed_forward3(src)
|
src = src + self.feed_forward3(src)
|
||||||
|
|
||||||
src = self.norm(src)
|
src = self.norm(src)
|
||||||
|
|
||||||
src = self.bypass(src_orig, src)
|
src = self.bypass(src_orig, src)
|
||||||
|
|
||||||
|
if self.use_adapters and self.adapter is not None:
|
||||||
|
src = self.adapter(src)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
src,
|
src,
|
||||||
cached_key,
|
cached_key,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user