Reduce lr_scales of some sub-modules

Daniel Povey 2023-01-05 18:50:04 +08:00
parent 90c02b471c
commit ccc38a97f7


@@ -1086,6 +1086,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             (4000.0, 0.0))
     ) -> None:
         super().__init__()
+        self.lr_scale = 0.75
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.query_head_dim = query_head_dim
@@ -1336,6 +1337,9 @@ class AttentionSqueeze(nn.Module):
                  hidden_dim: int,
                  bottleneck_dim: int = 16):
         super().__init__()
+        self.lr_scale = 0.5
         self.bottleneck_dim = bottleneck_dim
         self.in_proj = nn.Linear(embed_dim, hidden_dim,
@@ -1476,6 +1480,8 @@ class NonlinAttention(nn.Module):
     ) -> None:
         super().__init__()
+        self.lr_scale = 0.75
         self.hidden_channels = hidden_channels
         self.in_proj = nn.Linear(channels, hidden_channels * 2, bias=True)
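
The three hunks above attach an lr_scale attribute to selected sub-modules so that their parameters train with a reduced learning rate. The commit itself does not show how the attribute is consumed; below is a minimal, hypothetical sketch in plain PyTorch (not the icefall/ScaledAdam code) of how optimizer parameter groups could pick up such per-module scales, assuming a parameter's effective learning rate is the base rate times the product of the lr_scale values of its enclosing modules.

import torch
import torch.nn as nn


def make_param_groups(model: nn.Module, base_lr: float):
    # Each module may define an `lr_scale` attribute (default 1.0), as in the
    # diff above.  A parameter's effective lr is base_lr times the product of
    # the scales of all modules that contain it.  (Hypothetical helper; the
    # real training code may apply the scale differently.)
    scales = {name: getattr(mod, "lr_scale", 1.0)
              for name, mod in model.named_modules()}

    groups = {}
    for pname, param in model.named_parameters():
        parts = pname.split(".")
        scale = 1.0
        # Walk the module prefixes of the parameter name, e.g. for
        # "encoder.self_attn.in_proj.weight" the prefixes are
        # "", "encoder", "encoder.self_attn", "encoder.self_attn.in_proj".
        for i in range(len(parts)):
            scale *= scales.get(".".join(parts[:i]), 1.0)
        groups.setdefault(scale, []).append(param)

    return [{"params": params, "lr": base_lr * scale}
            for scale, params in groups.items()]


# Usage (hypothetical):
# optimizer = torch.optim.Adam(make_param_groups(model, base_lr=3e-3))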