Some fixes/refactoring, make parameters shared

Daniel Povey 2022-04-25 13:55:27 +08:00
parent 0d40b4617a
commit bb7cb82b04


@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import math
 import warnings
 from typing import Optional, Tuple
@@ -87,7 +86,10 @@ class Conformer(EncoderInterface):
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-        encoder_layer = ConformerEncoderLayer(
+        # Pass in a lambda that creates a new ConformerEncoderLayer with these
+        # args. Don't use deepcopy because we need the knowledge_base
+        # to be shared.
+        encoder_layer_fn = lambda: ConformerEncoderLayer(
             self.knowledge_base,
             d_model,
             nhead,
@@ -100,7 +102,7 @@ class Conformer(EncoderInterface):
             knowledge_D,
             knowledge_K
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+        self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers)
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -307,10 +309,10 @@ class ConformerEncoder(nn.Module):
         >>> out = conformer_encoder(src, pos_emb)
     """
-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(self, encoder_layer_fn, num_layers: int) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
-            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+            [encoder_layer_fn() for i in range(num_layers)]
         )
         self.num_layers = num_layers
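
The reason a factory lambda replaces copy.deepcopy here is that deepcopy would give every ConformerEncoderLayer its own private copy of the knowledge_base parameters, whereas calling the factory once per layer builds fresh layers that all hold a reference to the single shared module. A minimal sketch of that difference, with illustrative names only (not code from this commit):

# Sketch: sharing a submodule via a factory vs. duplicating it via deepcopy.
import copy
import torch.nn as nn

knowledge_base = nn.Embedding(256, 16)  # the module that should be shared

def layer_fn():
    layer = nn.Module()
    layer.knowledge_base = knowledge_base  # stores a reference, not a copy
    return layer

template = layer_fn()
copied = [copy.deepcopy(template) for _ in range(2)]  # old approach
shared = [layer_fn() for _ in range(2)]               # new approach

print(copied[0].knowledge_base is copied[1].knowledge_base)  # False: each copy has its own module
print(shared[0].knowledge_base is shared[1].knowledge_base)  # True: all layers share one module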