mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
Add learnable post-scale for mha
This commit is contained in:
parent
7eb5a84cbe
commit
76a2b9d362
@ -177,6 +177,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||||
|
|
||||||
self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
|
self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
|
||||||
|
self.post_scale_mha = ExpScale(1, speed=10.0, initial_scale=1.0)
|
||||||
self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5)
|
self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5)
|
||||||
self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
|
self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
|
||||||
self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
|
self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)
|
||||||
@ -230,14 +231,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
attn_mask=src_mask,
|
attn_mask=src_mask,
|
||||||
key_padding_mask=src_key_padding_mask,
|
key_padding_mask=src_key_padding_mask,
|
||||||
)[0]
|
)[0]
|
||||||
# natural rms scale of mha output is about 2 to 6. scaling down by 0.1 takes it
|
src = residual + post_scale_mha(self.dropout(src_att))
|
||||||
# to 0.2 to 0.6, which is suitable to add to the inputs assuming the output
|
|
||||||
# of the previous convolution layer had a magnitude of around 1.0
|
|
||||||
# (this magnitude of 1.0, or a bit less, like 0.3, is learned but is
|
|
||||||
# dictated by considerations of what is done to the output of the
|
|
||||||
# encoder.
|
|
||||||
post_scale_mha = 0.1
|
|
||||||
src = residual + post_scale_mha * self.dropout(src_att)
|
|
||||||
|
|
||||||
# convolution module
|
# convolution module
|
||||||
src = src + self.dropout(self.conv_module(self.scale_conv(src)))
|
src = src + self.dropout(self.conv_module(self.scale_conv(src)))
|
||||||
|
@ -110,7 +110,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/randcombine1_expscale3_rework_2.0",
|
default="transducer_stateless/randcombine1_expscale3_rework_2.0_b",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
Loading…
x
Reference in New Issue
Block a user