Some fixes/refactoring, make parameters shared

This commit is contained in:
Daniel Povey 2022-04-25 13:55:27 +08:00
parent 0d40b4617a
commit bb7cb82b04

View File

@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import math import math
import warnings import warnings
from typing import Optional, Tuple from typing import Optional, Tuple
@ -87,7 +86,10 @@ class Conformer(EncoderInterface):
self.encoder_pos = RelPositionalEncoding(d_model, dropout) self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer( # Pass in a lambda that creates a new ConformerEncoderLayer with these
# args. Don't use deepcopy because we need the knowledge_base
# to be shared.
encoder_layer_fn = lambda: ConformerEncoderLayer(
self.knowledge_base, self.knowledge_base,
d_model, d_model,
nhead, nhead,
@ -100,7 +102,7 @@ class Conformer(EncoderInterface):
knowledge_D, knowledge_D,
knowledge_K knowledge_K
) )
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers)
def forward( def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@ -307,10 +309,10 @@ class ConformerEncoder(nn.Module):
>>> out = conformer_encoder(src, pos_emb) >>> out = conformer_encoder(src, pos_emb)
""" """
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: def __init__(self, encoder_layer_fn, num_layers: int) -> None:
super().__init__() super().__init__()
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)] [encoder_layer_fn() for i in range(num_layers)]
) )
self.num_layers = num_layers self.num_layers = num_layers