Remove linear layers from RandomCombine

Daniel Povey 2022-06-10 11:38:50 +08:00
parent c92d9d72aa
commit 42667aacf9

@@ -306,7 +306,6 @@ class ConformerEncoder(nn.Module):
         num_channels = encoder_layer.norm_final.num_channels
         self.combiner = RandomCombine(
             num_inputs=len(self.aux_layers),
-            num_channels=num_channels,
             final_weight=0.5,
             pure_prob=0.333,
             stddev=2.0,
@@ -1081,7 +1080,6 @@ class RandomCombine(nn.Module):
     def __init__(
         self,
         num_inputs: int,
-        num_channels: int,
         final_weight: float = 0.5,
         pure_prob: float = 0.5,
         stddev: float = 2.0,
@@ -1092,8 +1090,6 @@ class RandomCombine(nn.Module):
             The number of tensor inputs, which equals the number of layers'
             outputs that are fed into this module. E.g. in an 18-layer neural
             net if we output layers 16, 12, 18, num_inputs would be 3.
-          num_channels:
-            The number of channels on the input, e.g. 512.
           final_weight:
             The amount of weight or probability we assign to the
             final layer when randomly choosing layers or when choosing
@ -1124,13 +1120,6 @@ class RandomCombine(nn.Module):
assert 0 < final_weight < 1, final_weight assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1 assert num_inputs >= 1
self.linear = nn.ModuleList(
[
nn.Linear(num_channels, num_channels, bias=True)
for _ in range(num_inputs - 1)
]
)
self.num_inputs = num_inputs self.num_inputs = num_inputs
self.final_weight = final_weight self.final_weight = final_weight
self.pure_prob = pure_prob self.pure_prob = pure_prob
@@ -1143,12 +1132,7 @@ class RandomCombine(nn.Module):
             .log()
             .item()
         )
-        self._reset_parameters()
 
-    def _reset_parameters(self):
-        for i in range(len(self.linear)):
-            nn.init.eye_(self.linear[i].weight)
-            nn.init.constant_(self.linear[i].bias, 0.0)
 
     def forward(self, inputs: List[Tensor]) -> Tensor:
         """Forward function.
@@ -1169,14 +1153,9 @@ class RandomCombine(nn.Module):
 
         num_channels = inputs[0].shape[-1]
         num_frames = inputs[0].numel() // num_channels
 
-        mod_inputs = []
-        for i in range(num_inputs - 1):
-            mod_inputs.append(self.linear[i](inputs[i]))
-        mod_inputs.append(inputs[num_inputs - 1])
-
         ndim = inputs[0].ndim
         # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
             (num_frames, num_channels, num_inputs)
         )
@@ -1290,7 +1269,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
     num_channels = 50
     m = RandomCombine(
         num_inputs=num_inputs,
-        num_channels=num_channels,
         final_weight=final_weight,
         pure_prob=pure_prob,
         stddev=stddev,
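
For context, below is a minimal self-contained sketch (a hypothetical helper, not code from this repository) of the combination step that the commit leaves in RandomCombine.forward(): the layer outputs are stacked directly, with no per-input nn.Linear projection, and mixed with per-frame weights. The uniform weights used here are only a stand-in for the random per-frame weights the real module draws during training.

# Hypothetical illustration; combine_sketch() is not part of the repository.
from typing import List

import torch
from torch import Tensor


def combine_sketch(inputs: List[Tensor]) -> Tensor:
    """Combine equally-shaped tensors, e.g. each of shape (seq_len, batch, channels)."""
    num_inputs = len(inputs)
    num_channels = inputs[0].shape[-1]
    num_frames = inputs[0].numel() // num_channels
    ndim = inputs[0].ndim

    # stacked_inputs: (num_frames, num_channels, num_inputs)
    stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
        (num_frames, num_channels, num_inputs)
    )

    # weights: (num_frames, num_inputs, 1); uniform here, illustrative only.
    weights = torch.full((num_frames, num_inputs, 1), 1.0 / num_inputs)

    # ans: (num_frames, num_channels, 1), then back to the original input shape.
    ans = torch.matmul(stacked_inputs, weights)
    return ans.reshape(inputs[0].shape)


if __name__ == "__main__":
    # e.g. the outputs of layers 12, 16 and 18 of an 18-layer encoder
    xs = [torch.randn(20, 4, 50) for _ in range(3)]
    y = combine_sketch(xs)
    assert y.shape == xs[0].shape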