Modifications to random combine (#452)

* comment out some lines; random combine starting from the layer at 1/3 of the encoder depth; no linear layers in combiner

* delete commented lines

* minor change
This commit is contained in:
Zengwei Yao 2022-06-30 12:23:49 +08:00 committed by GitHub
parent c10aec5656
commit d80f29e662
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -87,10 +87,17 @@ class Conformer(EncoderInterface):
layer_dropout, layer_dropout,
cnn_module_kernel, cnn_module_kernel,
) )
# aux_layers from 1/3
self.encoder = ConformerEncoder( self.encoder = ConformerEncoder(
encoder_layer, encoder_layer,
num_encoder_layers, num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)), aux_layers=list(
range(
num_encoder_layers // 3,
num_encoder_layers - 1,
aux_layer_period,
)
),
) )
def forward( def forward(
@ -295,10 +302,8 @@ class ConformerEncoder(nn.Module):
assert num_layers - 1 not in aux_layers assert num_layers - 1 not in aux_layers
self.aux_layers = aux_layers + [num_layers - 1] self.aux_layers = aux_layers + [num_layers - 1]
num_channels = encoder_layer.norm_final.num_channels
self.combiner = RandomCombine( self.combiner = RandomCombine(
num_inputs=len(self.aux_layers), num_inputs=len(self.aux_layers),
num_channels=num_channels,
final_weight=0.5, final_weight=0.5,
pure_prob=0.333, pure_prob=0.333,
stddev=2.0, stddev=2.0,
@ -1072,7 +1077,6 @@ class RandomCombine(nn.Module):
def __init__( def __init__(
self, self,
num_inputs: int, num_inputs: int,
num_channels: int,
final_weight: float = 0.5, final_weight: float = 0.5,
pure_prob: float = 0.5, pure_prob: float = 0.5,
stddev: float = 2.0, stddev: float = 2.0,
@ -1083,8 +1087,6 @@ class RandomCombine(nn.Module):
The number of tensor inputs, which equals the number of layers' The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3. net if we output layers 16, 12, 18, num_inputs would be 3.
num_channels:
The number of channels on the input, e.g. 512.
final_weight: final_weight:
The amount of weight or probability we assign to the The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing final layer when randomly choosing layers or when choosing
@ -1115,13 +1117,6 @@ class RandomCombine(nn.Module):
assert 0 < final_weight < 1, final_weight assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1 assert num_inputs >= 1
self.linear = nn.ModuleList(
[
nn.Linear(num_channels, num_channels, bias=True)
for _ in range(num_inputs - 1)
]
)
self.num_inputs = num_inputs self.num_inputs = num_inputs
self.final_weight = final_weight self.final_weight = final_weight
self.pure_prob = pure_prob self.pure_prob = pure_prob
@ -1134,12 +1129,6 @@ class RandomCombine(nn.Module):
.log() .log()
.item() .item()
) )
self._reset_parameters()
def _reset_parameters(self):
for i in range(len(self.linear)):
nn.init.eye_(self.linear[i].weight)
nn.init.constant_(self.linear[i].bias, 0.0)
def forward(self, inputs: List[Tensor]) -> Tensor: def forward(self, inputs: List[Tensor]) -> Tensor:
"""Forward function. """Forward function.
@ -1160,28 +1149,9 @@ class RandomCombine(nn.Module):
num_channels = inputs[0].shape[-1] num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels num_frames = inputs[0].numel() // num_channels
mod_inputs = []
if False:
# It throws the following error for torch 1.6.0 when using
# torch script.
#
# Expected integer literal for index. ModuleList/Sequential
# indexing is only supported with integer literals. Enumeration is
# supported, e.g. 'for index, v in enumerate(self): ...':
# for i in range(num_inputs - 1):
# mod_inputs.append(self.linear[i](inputs[i]))
assert False
else:
for i, linear in enumerate(self.linear):
if i < num_inputs - 1:
mod_inputs.append(linear(inputs[i]))
mod_inputs.append(inputs[num_inputs - 1])
ndim = inputs[0].ndim ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs) # stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape( stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs) (num_frames, num_channels, num_inputs)
) )