mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
Reduce the decorrelation scale (dropout_rate) by a factor of 5, from 0.05 to 0.01
This commit is contained in:
parent
7c6d923d3f
commit
53ca61db7a
@ -17,11 +17,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from scaling import (
|
from scaling import ScaledConv1d, ScaledEmbedding
|
||||||
ScaledConv1d,
|
|
||||||
ScaledEmbedding,
|
|
||||||
Decorrelate,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
@ -63,9 +59,6 @@ class Decoder(nn.Module):
|
|||||||
embedding_dim=decoder_dim,
|
embedding_dim=decoder_dim,
|
||||||
padding_idx=blank_id,
|
padding_idx=blank_id,
|
||||||
)
|
)
|
||||||
self.decorrelate = Decorrelate(apply_prob=0.25,
|
|
||||||
dropout_rate=0.05)
|
|
||||||
|
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
|
|
||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
@ -106,6 +99,5 @@ class Decoder(nn.Module):
|
|||||||
assert embedding_out.size(-1) == self.context_size
|
assert embedding_out.size(-1) == self.context_size
|
||||||
embedding_out = self.conv(embedding_out)
|
embedding_out = self.conv(embedding_out)
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
embedding_out = self.decorrelate(embedding_out)
|
|
||||||
embedding_out = F.relu(embedding_out)
|
embedding_out = F.relu(embedding_out)
|
||||||
return embedding_out
|
return embedding_out
|
||||||
|
@ -199,7 +199,8 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.dropout = torch.nn.Dropout(dropout)
|
self.dropout = torch.nn.Dropout(dropout)
|
||||||
self.decorrelate = Decorrelate(apply_prob=0.25, dropout_rate=0.05)
|
self.decorrelate = Decorrelate(apply_prob=0.25,
|
||||||
|
dropout_rate=0.01)
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user