mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 15:44:17 +00:00
Some fixes/refactoring, make parameters shared
This commit is contained in:
parent
0d40b4617a
commit
bb7cb82b04
@ -15,7 +15,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@ -87,7 +86,10 @@ class Conformer(EncoderInterface):
|
|||||||
|
|
||||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||||
|
|
||||||
encoder_layer = ConformerEncoderLayer(
|
# Pass in a lambda that creates a new ConformerEncoderLayer with these
|
||||||
|
# args. Don't use deepcopy because we need the knowledge_base
|
||||||
|
# to be shared.
|
||||||
|
encoder_layer_fn = lambda: ConformerEncoderLayer(
|
||||||
self.knowledge_base,
|
self.knowledge_base,
|
||||||
d_model,
|
d_model,
|
||||||
nhead,
|
nhead,
|
||||||
@ -100,7 +102,7 @@ class Conformer(EncoderInterface):
|
|||||||
knowledge_D,
|
knowledge_D,
|
||||||
knowledge_K
|
knowledge_K
|
||||||
)
|
)
|
||||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||||
@ -307,10 +309,10 @@ class ConformerEncoder(nn.Module):
|
|||||||
>>> out = conformer_encoder(src, pos_emb)
|
>>> out = conformer_encoder(src, pos_emb)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
def __init__(self, encoder_layer_fn, num_layers: int) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
[encoder_layer_fn() for i in range(num_layers)]
|
||||||
)
|
)
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user