Comment out some lines; random combine from 1/3 of the layers onward; no linear layers in combiner
This commit is contained in:
parent
2a47db653b
commit
0c455db55d
@ -87,10 +87,17 @@ class Conformer(EncoderInterface):
|
||||
layer_dropout,
|
||||
cnn_module_kernel,
|
||||
)
|
||||
# aux_layers start at 1/3 of the encoder depth (num_encoder_layers // 3)
|
||||
self.encoder = ConformerEncoder(
|
||||
encoder_layer,
|
||||
num_encoder_layers,
|
||||
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
|
||||
aux_layers=list(
|
||||
range(
|
||||
num_encoder_layers // 3,
|
||||
num_encoder_layers - 1,
|
||||
aux_layer_period,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -296,10 +303,10 @@ class ConformerEncoder(nn.Module):
|
||||
assert num_layers - 1 not in aux_layers
|
||||
self.aux_layers = set(aux_layers + [num_layers - 1])
|
||||
|
||||
num_channels = encoder_layer.norm_final.num_channels
|
||||
# num_channels = encoder_layer.norm_final.num_channels
|
||||
self.combiner = RandomCombine(
|
||||
num_inputs=len(self.aux_layers),
|
||||
num_channels=num_channels,
|
||||
# num_channels=num_channels,
|
||||
final_weight=0.5,
|
||||
pure_prob=0.333,
|
||||
stddev=2.0,
|
||||
@ -1073,7 +1080,7 @@ class RandomCombine(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_inputs: int,
|
||||
num_channels: int,
|
||||
# num_channels: int,
|
||||
final_weight: float = 0.5,
|
||||
pure_prob: float = 0.5,
|
||||
stddev: float = 2.0,
|
||||
@ -1116,12 +1123,12 @@ class RandomCombine(nn.Module):
|
||||
assert 0 < final_weight < 1, final_weight
|
||||
assert num_inputs >= 1
|
||||
|
||||
self.linear = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(num_channels, num_channels, bias=True)
|
||||
for _ in range(num_inputs - 1)
|
||||
]
|
||||
)
|
||||
# self.linear = nn.ModuleList(
|
||||
# [
|
||||
# nn.Linear(num_channels, num_channels, bias=True)
|
||||
# for _ in range(num_inputs - 1)
|
||||
# ]
|
||||
# )
|
||||
|
||||
self.num_inputs = num_inputs
|
||||
self.final_weight = final_weight
|
||||
@ -1135,12 +1142,13 @@ class RandomCombine(nn.Module):
|
||||
.log()
|
||||
.item()
|
||||
)
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
for i in range(len(self.linear)):
|
||||
nn.init.eye_(self.linear[i].weight)
|
||||
nn.init.constant_(self.linear[i].bias, 0.0)
|
||||
# self._reset_parameters()
|
||||
|
||||
# def _reset_parameters(self):
|
||||
# for i in range(len(self.linear)):
|
||||
# nn.init.eye_(self.linear[i].weight)
|
||||
# nn.init.constant_(self.linear[i].bias, 0.0)
|
||||
|
||||
def forward(self, inputs: List[Tensor]) -> Tensor:
|
||||
"""Forward function.
|
||||
@ -1163,7 +1171,8 @@ class RandomCombine(nn.Module):
|
||||
|
||||
mod_inputs = []
|
||||
for i in range(num_inputs - 1):
|
||||
mod_inputs.append(self.linear[i](inputs[i]))
|
||||
# mod_inputs.append(self.linear[i](inputs[i]))
|
||||
mod_inputs.append(inputs[i])
|
||||
mod_inputs.append(inputs[num_inputs - 1])
|
||||
|
||||
ndim = inputs[0].ndim
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user