mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove linear layers from RandomCombine
This commit is contained in:
parent
c92d9d72aa
commit
42667aacf9
@ -306,7 +306,6 @@ class ConformerEncoder(nn.Module):
|
|||||||
num_channels = encoder_layer.norm_final.num_channels
|
num_channels = encoder_layer.norm_final.num_channels
|
||||||
self.combiner = RandomCombine(
|
self.combiner = RandomCombine(
|
||||||
num_inputs=len(self.aux_layers),
|
num_inputs=len(self.aux_layers),
|
||||||
num_channels=num_channels,
|
|
||||||
final_weight=0.5,
|
final_weight=0.5,
|
||||||
pure_prob=0.333,
|
pure_prob=0.333,
|
||||||
stddev=2.0,
|
stddev=2.0,
|
||||||
@ -1081,7 +1080,6 @@ class RandomCombine(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inputs: int,
|
num_inputs: int,
|
||||||
num_channels: int,
|
|
||||||
final_weight: float = 0.5,
|
final_weight: float = 0.5,
|
||||||
pure_prob: float = 0.5,
|
pure_prob: float = 0.5,
|
||||||
stddev: float = 2.0,
|
stddev: float = 2.0,
|
||||||
@ -1092,8 +1090,6 @@ class RandomCombine(nn.Module):
|
|||||||
The number of tensor inputs, which equals the number of layers'
|
The number of tensor inputs, which equals the number of layers'
|
||||||
outputs that are fed into this module. E.g. in an 18-layer neural
|
outputs that are fed into this module. E.g. in an 18-layer neural
|
||||||
net if we output layers 16, 12, 18, num_inputs would be 3.
|
net if we output layers 16, 12, 18, num_inputs would be 3.
|
||||||
num_channels:
|
|
||||||
The number of channels on the input, e.g. 512.
|
|
||||||
final_weight:
|
final_weight:
|
||||||
The amount of weight or probability we assign to the
|
The amount of weight or probability we assign to the
|
||||||
final layer when randomly choosing layers or when choosing
|
final layer when randomly choosing layers or when choosing
|
||||||
@ -1124,13 +1120,6 @@ class RandomCombine(nn.Module):
|
|||||||
assert 0 < final_weight < 1, final_weight
|
assert 0 < final_weight < 1, final_weight
|
||||||
assert num_inputs >= 1
|
assert num_inputs >= 1
|
||||||
|
|
||||||
self.linear = nn.ModuleList(
|
|
||||||
[
|
|
||||||
nn.Linear(num_channels, num_channels, bias=True)
|
|
||||||
for _ in range(num_inputs - 1)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.num_inputs = num_inputs
|
self.num_inputs = num_inputs
|
||||||
self.final_weight = final_weight
|
self.final_weight = final_weight
|
||||||
self.pure_prob = pure_prob
|
self.pure_prob = pure_prob
|
||||||
@ -1143,12 +1132,7 @@ class RandomCombine(nn.Module):
|
|||||||
.log()
|
.log()
|
||||||
.item()
|
.item()
|
||||||
)
|
)
|
||||||
self._reset_parameters()
|
|
||||||
|
|
||||||
def _reset_parameters(self):
|
|
||||||
for i in range(len(self.linear)):
|
|
||||||
nn.init.eye_(self.linear[i].weight)
|
|
||||||
nn.init.constant_(self.linear[i].bias, 0.0)
|
|
||||||
|
|
||||||
def forward(self, inputs: List[Tensor]) -> Tensor:
|
def forward(self, inputs: List[Tensor]) -> Tensor:
|
||||||
"""Forward function.
|
"""Forward function.
|
||||||
@ -1169,14 +1153,9 @@ class RandomCombine(nn.Module):
|
|||||||
num_channels = inputs[0].shape[-1]
|
num_channels = inputs[0].shape[-1]
|
||||||
num_frames = inputs[0].numel() // num_channels
|
num_frames = inputs[0].numel() // num_channels
|
||||||
|
|
||||||
mod_inputs = []
|
|
||||||
for i in range(num_inputs - 1):
|
|
||||||
mod_inputs.append(self.linear[i](inputs[i]))
|
|
||||||
mod_inputs.append(inputs[num_inputs - 1])
|
|
||||||
|
|
||||||
ndim = inputs[0].ndim
|
ndim = inputs[0].ndim
|
||||||
# stacked_inputs: (num_frames, num_channels, num_inputs)
|
# stacked_inputs: (num_frames, num_channels, num_inputs)
|
||||||
stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
|
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
|
||||||
(num_frames, num_channels, num_inputs)
|
(num_frames, num_channels, num_inputs)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1290,7 +1269,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
|
|||||||
num_channels = 50
|
num_channels = 50
|
||||||
m = RandomCombine(
|
m = RandomCombine(
|
||||||
num_inputs=num_inputs,
|
num_inputs=num_inputs,
|
||||||
num_channels=num_channels,
|
|
||||||
final_weight=final_weight,
|
final_weight=final_weight,
|
||||||
pure_prob=pure_prob,
|
pure_prob=pure_prob,
|
||||||
stddev=stddev,
|
stddev=stddev,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user