mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Make weight in SimpleCombine a vector
This commit is contained in:
parent
e08f5c1bce
commit
0379ab57a2
@ -341,11 +341,6 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
feedforward_dim,
|
feedforward_dim,
|
||||||
dropout)
|
dropout)
|
||||||
|
|
||||||
self.feed_forward3 = FeedforwardModule(d_model,
|
|
||||||
feedforward_dim,
|
|
||||||
dropout)
|
|
||||||
|
|
||||||
|
|
||||||
self.conv_module1 = ConvolutionModule(d_model,
|
self.conv_module1 = ConvolutionModule(d_model,
|
||||||
cnn_module_kernel)
|
cnn_module_kernel)
|
||||||
|
|
||||||
@ -451,16 +446,13 @@ class ZipformerEncoderLayer(nn.Module):
|
|||||||
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
|
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
|
||||||
src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
|
src = src + self.conv_module1(src, src_key_padding_mask=src_key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
src = src + self.feed_forward2(src)
|
|
||||||
|
|
||||||
if torch.jit.is_scripting() or use_self_attn:
|
if torch.jit.is_scripting() or use_self_attn:
|
||||||
src = src + self.self_attn.forward2(src, attn_weights)
|
src = src + self.self_attn.forward2(src, attn_weights)
|
||||||
|
|
||||||
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
|
if torch.jit.is_scripting() or random.random() > dynamic_dropout:
|
||||||
src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
|
src = src + self.conv_module2(src, src_key_padding_mask=src_key_padding_mask)
|
||||||
|
|
||||||
src = src + self.feed_forward3(src)
|
src = src + self.feed_forward2(src)
|
||||||
|
|
||||||
src = self.norm_final(self.balancer(src))
|
src = self.norm_final(self.balancer(src))
|
||||||
|
|
||||||
@ -855,7 +847,7 @@ class SimpleCombiner(torch.nn.Module):
|
|||||||
min_weight: Tuple[float] = (0., 0.)):
|
min_weight: Tuple[float] = (0., 0.)):
|
||||||
super(SimpleCombiner, self).__init__()
|
super(SimpleCombiner, self).__init__()
|
||||||
assert dim2 >= dim1
|
assert dim2 >= dim1
|
||||||
self.weight1 = nn.Parameter(torch.zeros(()))
|
self.weight1 = nn.Parameter(torch.ones(dim2) * min_weight[0])
|
||||||
self.min_weight = min_weight
|
self.min_weight = min_weight
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
@ -878,9 +870,6 @@ class SimpleCombiner(torch.nn.Module):
|
|||||||
max=1.0-self.min_weight[1])
|
max=1.0-self.min_weight[1])
|
||||||
|
|
||||||
|
|
||||||
src1 = src1 * weight1
|
|
||||||
src2 = src2 * (1.0 - weight1)
|
|
||||||
|
|
||||||
src1_dim = src1.shape[-1]
|
src1_dim = src1.shape[-1]
|
||||||
src2_dim = src2.shape[-1]
|
src2_dim = src2.shape[-1]
|
||||||
if src1_dim != src2_dim:
|
if src1_dim != src2_dim:
|
||||||
@ -893,6 +882,8 @@ class SimpleCombiner(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
src1 = src1[:src2_dim]
|
src1 = src1[:src2_dim]
|
||||||
|
|
||||||
|
src1 = src1 * weight1
|
||||||
|
src2 = src2 * (1.0 - weight1)
|
||||||
|
|
||||||
return src1 + src2
|
return src1 + src2
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user