mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Change how scales are applied; fix residual bug
This commit is contained in:
parent
bec33e6855
commit
5eafccb369
@ -229,10 +229,17 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
attn_mask=src_mask,
|
attn_mask=src_mask,
|
||||||
key_padding_mask=src_key_padding_mask,
|
key_padding_mask=src_key_padding_mask,
|
||||||
)[0]
|
)[0]
|
||||||
src = residual + self.dropout(src_att)
|
# natural rms scale of mha output is about 2 to 6. scaling down by 0.1 takes it
|
||||||
|
# to 0.2 to 0.6, which is suitable to add to the inputs assuming the output
|
||||||
|
# of the previous convolution layer had a magnitude of around 1.0
|
||||||
|
# (this magnitude of 1.0, or a bit less, like 0.3, is learned but is
|
||||||
|
# dictated by considerations of what is done to the output of the
|
||||||
|
# encoder.
|
||||||
|
post_scale_mha = 0.1
|
||||||
|
src = residual + post_scale_mha * self.dropout(src_att)
|
||||||
|
|
||||||
# convolution module
|
# convolution module
|
||||||
src = residual + self.dropout(self.conv_module(self.scale_conv(src)))
|
src = src + self.dropout(self.conv_module(self.scale_conv(src)))
|
||||||
|
|
||||||
# feed forward module
|
# feed forward module
|
||||||
src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
|
src = src + self.dropout(self.feed_forward(self.scale_ff(src)))
|
||||||
@ -891,13 +898,15 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
# 1D Depthwise Conv
|
# 1D Depthwise Conv
|
||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
|
|
||||||
|
# TODO: can have a learned scale in here, or a fixed one.
|
||||||
|
x = self.activation(x)
|
||||||
|
|
||||||
# x is (batch, channels, time)
|
# x is (batch, channels, time)
|
||||||
x = x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
x = self.scale(x)
|
x = self.scale(x)
|
||||||
x = x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
x = self.activation(x)
|
|
||||||
|
|
||||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||||
|
|
||||||
return x.permute(2, 0, 1)
|
return x.permute(2, 0, 1)
|
||||||
|
@ -110,7 +110,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs_cinit",
|
default="transducer_stateless/randcombine1_expscale3_rework",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
Loading…
x
Reference in New Issue
Block a user