Modification about random combine (#452)

* comment some lines, random combine from 1/3 layers, on linear layers in combiner

* delete commented lines

* minor change
This commit is contained in:
Zengwei Yao 2022-06-30 12:23:49 +08:00 committed by GitHub
parent c10aec5656
commit d80f29e662
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -87,10 +87,17 @@ class Conformer(EncoderInterface):
layer_dropout,
cnn_module_kernel,
)
# aux_layers from 1/3
self.encoder = ConformerEncoder(
encoder_layer,
num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
aux_layers=list(
range(
num_encoder_layers // 3,
num_encoder_layers - 1,
aux_layer_period,
)
),
)
def forward(
@ -295,10 +302,8 @@ class ConformerEncoder(nn.Module):
assert num_layers - 1 not in aux_layers
self.aux_layers = aux_layers + [num_layers - 1]
num_channels = encoder_layer.norm_final.num_channels
self.combiner = RandomCombine(
num_inputs=len(self.aux_layers),
num_channels=num_channels,
final_weight=0.5,
pure_prob=0.333,
stddev=2.0,
@ -1072,7 +1077,6 @@ class RandomCombine(nn.Module):
def __init__(
self,
num_inputs: int,
num_channels: int,
final_weight: float = 0.5,
pure_prob: float = 0.5,
stddev: float = 2.0,
@ -1083,8 +1087,6 @@ class RandomCombine(nn.Module):
The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
num_channels:
The number of channels on the input, e.g. 512.
final_weight:
The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing
@ -1115,13 +1117,6 @@ class RandomCombine(nn.Module):
assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1
self.linear = nn.ModuleList(
[
nn.Linear(num_channels, num_channels, bias=True)
for _ in range(num_inputs - 1)
]
)
self.num_inputs = num_inputs
self.final_weight = final_weight
self.pure_prob = pure_prob
@ -1134,12 +1129,6 @@ class RandomCombine(nn.Module):
.log()
.item()
)
self._reset_parameters()
def _reset_parameters(self):
for i in range(len(self.linear)):
nn.init.eye_(self.linear[i].weight)
nn.init.constant_(self.linear[i].bias, 0.0)
def forward(self, inputs: List[Tensor]) -> Tensor:
"""Forward function.
@ -1160,28 +1149,9 @@ class RandomCombine(nn.Module):
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
mod_inputs = []
if False:
# It throws the following error for torch 1.6.0 when using
# torch script.
#
# Expected integer literal for index. ModuleList/Sequential
# indexing is only supported with integer literals. Enumeration is
# supported, e.g. 'for index, v in enumerate(self): ...':
# for i in range(num_inputs - 1):
# mod_inputs.append(self.linear[i](inputs[i]))
assert False
else:
for i, linear in enumerate(self.linear):
if i < num_inputs - 1:
mod_inputs.append(linear(inputs[i]))
mod_inputs.append(inputs[num_inputs - 1])
ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs)
)