diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py index 4e4695fa5..8e2dfdd72 100644 --- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -1004,6 +1004,9 @@ class Zipformer2EncoderLayer(nn.Module): ) src = src + self_attn + if self.use_adapters and self.post_sa_adapter is not None: + src = self.post_sa_adapter(src) + src_conv, cached_conv1 = self.conv_module1.streaming_forward( src, cache=cached_conv1, @@ -1016,6 +1019,9 @@ class Zipformer2EncoderLayer(nn.Module): # bypass in the middle of the layer. src = self.bypass_mid(src_orig, src) + if self.use_adapters and self.mid_adapter is not None: + src = self.mid_adapter(src) + self_attn, cached_val2 = self.self_attn2.streaming_forward( src, attn_weights=attn_weights, @@ -1031,12 +1037,18 @@ class Zipformer2EncoderLayer(nn.Module): ) src = src + src_conv + if self.use_adapters and self.post_conv_adapter is not None: + src = self.post_conv_adapter(src) + src = src + self.feed_forward3(src) src = self.norm(src) src = self.bypass(src_orig, src) + if self.use_adapters and self.adapter is not None: + src = self.adapter(src) + return ( src, cached_key,