mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Some fixes/refactoring, make parameters shared
This commit is contained in:
parent
0d40b4617a
commit
bb7cb82b04
@ -15,7 +15,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
@ -87,7 +86,10 @@ class Conformer(EncoderInterface):
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
encoder_layer = ConformerEncoderLayer(
|
||||
# Pass in a lambda that creates a new ConformerEncoderLayer with these
|
||||
# args. Don't use deepcopy because we need the knowledge_base
|
||||
# to be shared.
|
||||
encoder_layer_fn = lambda: ConformerEncoderLayer(
|
||||
self.knowledge_base,
|
||||
d_model,
|
||||
nhead,
|
||||
@ -100,7 +102,7 @@ class Conformer(EncoderInterface):
|
||||
knowledge_D,
|
||||
knowledge_K
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||
@ -307,10 +309,10 @@ class ConformerEncoder(nn.Module):
|
||||
>>> out = conformer_encoder(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
||||
def __init__(self, encoder_layer_fn, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
[encoder_layer_fn() for i in range(num_layers)]
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user