mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-10 01:24:19 +00:00
Add decorrelation to joiner
This commit is contained in:
parent
cd6b707e2b
commit
7c6d923d3f
@ -17,7 +17,11 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from scaling import ScaledConv1d, ScaledEmbedding
|
from scaling import (
|
||||||
|
ScaledConv1d,
|
||||||
|
ScaledEmbedding,
|
||||||
|
Decorrelate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
@ -59,6 +63,9 @@ class Decoder(nn.Module):
|
|||||||
embedding_dim=decoder_dim,
|
embedding_dim=decoder_dim,
|
||||||
padding_idx=blank_id,
|
padding_idx=blank_id,
|
||||||
)
|
)
|
||||||
|
self.decorrelate = Decorrelate(apply_prob=0.25,
|
||||||
|
dropout_rate=0.05)
|
||||||
|
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
|
|
||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
@ -99,5 +106,6 @@ class Decoder(nn.Module):
|
|||||||
assert embedding_out.size(-1) == self.context_size
|
assert embedding_out.size(-1) == self.context_size
|
||||||
embedding_out = self.conv(embedding_out)
|
embedding_out = self.conv(embedding_out)
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
|
embedding_out = self.decorrelate(embedding_out)
|
||||||
embedding_out = F.relu(embedding_out)
|
embedding_out = F.relu(embedding_out)
|
||||||
return embedding_out
|
return embedding_out
|
||||||
|
Loading…
x
Reference in New Issue
Block a user