From f36ebad618cd6cb7e1223c22fcd257c27f16bfe9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 17 Jul 2022 04:24:48 +0800 Subject: [PATCH] Remove 2/3 StructuredLinear/StructuredConv1d modules, use linear/conv1d --- .../ASR/pruned_transducer_stateless7/conformer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 491416ec1..aa0fb22b8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -466,7 +466,7 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - self.in_proj = StructuredLinear((embed_dim,), (3, embed_dim), bias=True) + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0) self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0, max_positive=1.0, max_abs=10.0) @@ -544,8 +544,8 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb, self.embed_dim, self.num_heads, - self.in_proj.get_weight(), - self.in_proj.get_bias(), + self.in_proj.weight, + self.in_proj.bias, self.dropout, self.out_proj.weight, self.out_proj.bias, @@ -881,9 +881,9 @@ class ConvolutionModule(nn.Module): # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 - self.pointwise_conv1 = StructuredConv1d( - (channels,), - (2, channels), + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, kernel_size=1, stride=1, padding=0,