Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Remove StructuredLinear,StructuredConv1d

This commit is contained in:
  parent 1a184596b6
  commit 4a2b940321
@@ -168,151 +168,6 @@ class BasicNorm(torch.nn.Module):
        return x * scales


class StructuredLinear(torch.nn.Module):
    """
    This module mostly behaves like nn.Linear, but the in_features and out_features
    (the numbers of input and output channels) are specified as tuples; the
    actual numbers of channels are the products over these tuples.
    E.g. (2, 256) means 512, with the slowest-varying/largest-stride dims first
    in terms of the layout.
    For purposes of the forward() function it behaves the same as if the dim
    were 512, but the parameter tensors have this structure, which makes
    a difference if you are using the NeutralGradient optimizer and perhaps
    certain other optimizers.

    Args:
        in_features: The number of input channels, specified as
            a tuple of ints (the number of input channels will be their
            product). The only difference this makes is that the
            nn.Parameter tensor will be shaped differently, which may
            affect some optimizers.
        out_features: The number of output channels, specified as
            a tuple of ints.
        initial_scale: The default initial parameter scale will be
            multiplied by this.
        bias: If true, include the bias term.
    """
    def __init__(self,
                 in_features: Tuple[int, ...],
                 out_features: Tuple[int, ...],
                 bias: bool = True,
                 initial_scale: float = 1.0) -> None:
        super(StructuredLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        in_size = reduce((lambda i, j: i * j), in_features)
        out_size = reduce((lambda i, j: i * j), out_features)
        self.weight_shape = (out_size, in_size)
        self.weight = nn.Parameter(torch.Tensor(*out_features, *in_features))

        if bias:
            self.bias = nn.Parameter(torch.Tensor(*out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters(initial_scale)

    def reset_parameters(self, initial_scale: float = 1.0) -> None:
        nn.init.kaiming_uniform_(self.weight.reshape(*self.weight_shape), a=(5 ** 0.5))
        with torch.no_grad():
            self.weight *= initial_scale
        # Guard the bias init: with bias=False, self.bias is None.
        if self.bias is not None:
            nn.init.uniform_(self.bias,
                             -0.1 * initial_scale,
                             0.1 * initial_scale)

    def get_weight(self) -> Tensor:
        return self.weight.reshape(*self.weight_shape)

    def get_bias(self) -> Optional[Tensor]:
        return (None if self.bias is None else
                self.bias.reshape(self.weight_shape[0]))

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.get_weight(), self.get_bias())

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

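For readers skimming the removed code, here is a minimal usage sketch (not part of the diff) of how StructuredLinear behaved, assuming the usual imports at the top of scaling.py (torch, torch.nn as nn, torch.nn.functional as F, functools.reduce, typing.Tuple); the shapes shown are read off the constructor above:

# Hypothetical usage sketch, not from the diff: (2, 256) multiplies out to 512,
# so forward() acts like nn.Linear(512, 512); only the parameter layout differs.
m = StructuredLinear((2, 256), (512,))
x = torch.randn(10, 512)        # batch of 10 vectors, 2 * 256 = 512 channels
y = m(x)
print(m.weight.shape)           # torch.Size([512, 2, 256]) -- structured layout
print(m.get_weight().shape)     # torch.Size([512, 512])    -- flat view used by forward()
print(y.shape)                  # torch.Size([10, 512])

The structured shape only matters to optimizers (such as NeutralGradient) that inspect parameter tensor shapes; the flattened view fed to F.linear is numerically identical to an ordinary weight matrix.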
class StructuredConv1d(nn.Conv1d):
    """
    This module mostly behaves like nn.Conv1d, but the
    in_channels and out_channels are specified as tuples. For example,
    512 channels might be specified as
    (2, 256), with slowest-varying/largest-stride dims first in terms of the layout.
    For purposes of the forward() function it behaves the same as if the dim
    were 512, but the parameter tensors have this structure, which makes
    a difference if you are using the NeutralGradient optimizer.

    Args:
        in_channels: The number of input channels, specified as
            a tuple of ints (the number of input channels will be their
            product). The only difference this makes is that the
            nn.Parameter tensor will be shaped differently, which may
            affect some optimizers.
        out_channels: The number of output channels, specified as
            a tuple of ints.
        initial_scale: The default initial parameter scale will be
            multiplied by this.
        bias: If true, include the bias term.
    """
    def __init__(
            self,
            in_channels: Tuple[int, ...],
            out_channels: Tuple[int, ...],
            *args,
            initial_scale: float = 1.0,
            **kwargs
    ):
        super(StructuredConv1d, self).__init__(
            reduce((lambda i, j: i * j), in_channels),
            reduce((lambda i, j: i * j), out_channels),
            *args, **kwargs)

        assert self.groups == 1, "Groups not supported as yet"

        self.in_channels = in_channels
        self.out_channels = out_channels

        if self.transposed:
            in_channels, out_channels = out_channels, in_channels

        self.weight_shape = self.weight.shape
        self.weight = nn.Parameter(self.weight.detach().reshape(
            *out_channels, *in_channels, *self.weight.shape[2:]))

        # Only record and reshape the bias if it exists (bias=False leaves it None).
        if self.bias is not None:
            self.bias_shape = self.bias.shape
            self.bias = nn.Parameter(self.bias.detach().reshape(
                *out_channels))

        # These changes in the initialization are the same as for class ScaledConv1d.
        with torch.no_grad():
            self.weight[:] *= initial_scale
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias,
                                       -0.1 * initial_scale,
                                       0.1 * initial_scale)

    def get_weight(self) -> Tensor:
        return self.weight.reshape(*self.weight_shape)

    def get_bias(self) -> Optional[Tensor]:
        return (None if self.bias is None else
                self.bias.reshape(*self.bias_shape))

    def forward(self, input: Tensor) -> Tensor:
        if self.padding_mode != 'zeros':
            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice,
                                  mode=self.padding_mode),
                            self.get_weight(), self.get_bias(), self.stride,
                            _single(0), self.dilation, self.groups)
        return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride,
                        self.padding, self.dilation, self.groups)

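Again a minimal usage sketch (not part of the diff), under the same import assumptions, plus _single from torch.nn.modules.utils, which the padding_mode branch above relies on; shapes follow from the reshape in __init__:

# Hypothetical usage sketch, not from the diff: behaves like
# nn.Conv1d(512, 512, kernel_size=3, padding=1) in forward().
m = StructuredConv1d((2, 256), (512,), kernel_size=3, padding=1)
x = torch.randn(10, 512, 100)   # (batch, channels, time); 2 * 256 = 512
y = m(x)
print(m.weight.shape)           # torch.Size([512, 2, 256, 3]) -- structured layout
print(m.get_weight().shape)     # torch.Size([512, 512, 3])    -- flat view passed to F.conv1d
print(y.shape)                  # torch.Size([10, 512, 100])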
def ScaledLinear(*args,
                 initial_scale: float = 1.0,