From bb7cb82b04c53de7c241b2b8c3f14ddee90352cb Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 25 Apr 2022 13:55:27 +0800
Subject: [PATCH] Some fixes/refactoring, make parameters shared

---
 egs/librispeech/ASR/pruned2_knowledge/conformer.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py
index e07aba60b..83be579ec 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import math
 import warnings
 from typing import Optional, Tuple
@@ -87,7 +86,10 @@ class Conformer(EncoderInterface):
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
 
-        encoder_layer = ConformerEncoderLayer(
+        # Pass in a lambda that creates a new ConformerEncoderLayer with these
+        # args.  Don't use deepcopy because we need the knowledge_base
+        # to be shared.
+        encoder_layer_fn = lambda: ConformerEncoderLayer(
             self.knowledge_base,
             d_model,
             nhead,
@@ -100,7 +102,7 @@
             knowledge_D,
             knowledge_K
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+        self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers)
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -307,10 +309,10 @@ class ConformerEncoder(nn.Module):
         >>> out = conformer_encoder(src, pos_emb)
     """
 
-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(self, encoder_layer_fn, num_layers: int) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
-            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+            [encoder_layer_fn() for i in range(num_layers)]
         )
         self.num_layers = num_layers
 
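
Note for readers of the patch: the reason for replacing copy.deepcopy with a
factory function is that deepcopy clones every submodule of the template
layer, including the knowledge_base that all encoder layers are meant to
share, whereas a factory builds fresh layers that each hold a reference to
the same knowledge_base instance.  Below is a minimal, self-contained sketch
of the pattern; the `shared`, `layer_fn`, `base`, and `proj` names are
illustrative stand-ins, not code from this patch or from conformer.py.

    import copy
    import torch.nn as nn

    # `shared` plays the role of self.knowledge_base in the patch; the layer
    # layout is made up for illustration.
    shared = nn.Linear(4, 4)

    def layer_fn():
        # Every call builds a brand-new layer, but each layer holds a
        # reference to the *same* `shared` module, so those parameters are
        # shared across all layers.
        return nn.ModuleDict({"base": shared, "proj": nn.Linear(4, 4)})

    layers_shared = nn.ModuleList([layer_fn() for _ in range(3)])

    # The old deepcopy approach: each copy gets its own private clone of
    # `shared`, silently breaking the intended parameter sharing.
    layers_copied = nn.ModuleList(
        [copy.deepcopy(layer_fn()) for _ in range(3)]
    )

    # With the factory, every layer sees the same parameter tensor.
    assert all(l["base"].weight is shared.weight for l in layers_shared)
    # After deepcopy, no layer's weight is the shared tensor.
    assert not any(l["base"].weight is shared.weight for l in layers_copied)

This is exactly the mechanism the patch relies on: the lambda closes over
self.knowledge_base and passes that one instance to every
ConformerEncoderLayer it constructs.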