Remove 2/3 StructuredLinear/StructuredConv1d modules, use linear/conv1d

This commit is contained in:
Daniel Povey 2022-07-17 04:24:48 +08:00
parent 7e88e2a0e9
commit f36ebad618

View File

@@ -466,7 +466,7 @@ class RelPositionMultiheadAttention(nn.Module):
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = StructuredLinear((embed_dim,), (3, embed_dim), bias=True)
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0,
max_positive=1.0, max_abs=10.0)
@@ -544,8 +544,8 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.get_weight(),
self.in_proj.get_bias(),
self.in_proj.weight,
self.in_proj.bias,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
@@ -881,9 +881,9 @@ class ConvolutionModule(nn.Module):
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = StructuredConv1d(
(channels,),
(2, channels),
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,