mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 00:54:18 +00:00
Add decorrelation to joiner
This commit is contained in:
parent
cd6b707e2b
commit
7c6d923d3f
@ -17,7 +17,11 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scaling import ScaledConv1d, ScaledEmbedding
|
||||
from scaling import (
|
||||
ScaledConv1d,
|
||||
ScaledEmbedding,
|
||||
Decorrelate,
|
||||
)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
@ -59,6 +63,9 @@ class Decoder(nn.Module):
|
||||
embedding_dim=decoder_dim,
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.decorrelate = Decorrelate(apply_prob=0.25,
|
||||
dropout_rate=0.05)
|
||||
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
@ -99,5 +106,6 @@ class Decoder(nn.Module):
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = self.decorrelate(embedding_out)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
return embedding_out
|
||||
|
Loading…
x
Reference in New Issue
Block a user