mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
change for decoder.py
This commit is contained in:
parent
18b8c670dd
commit
eb48ade752
@ -52,7 +52,7 @@ class Decoder(nn.Module):
|
|||||||
Number of previous words to use to predict the next word.
|
Number of previous words to use to predict the next word.
|
||||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||||
modify_embedding_layer:
|
modify_embedding_layer:
|
||||||
Decide to modify the decoder embedding layer or not.
|
Whether modify the decoder embedding layer or not.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding = nn.Embedding(
|
self.embedding = nn.Embedding(
|
||||||
@ -61,12 +61,18 @@ class Decoder(nn.Module):
|
|||||||
padding_idx=blank_id,
|
padding_idx=blank_id,
|
||||||
)
|
)
|
||||||
self.modify_embedding_layer = modify_embedding_layer
|
self.modify_embedding_layer = modify_embedding_layer
|
||||||
self.embedding_dim_inv_sqrt = embedding_dim ** 0.5
|
self.embedding_dim_sqrt = embedding_dim ** 0.5
|
||||||
|
embedding_weight = torch.randn(vocab_size, embedding_dim) * (
|
||||||
|
1 / self.embedding_dim_sqrt
|
||||||
|
)
|
||||||
if self.modify_embedding_layer:
|
if self.modify_embedding_layer:
|
||||||
embedding_weight = torch.randn(vocab_size, embedding_dim) * (
|
self.embedding = nn.Embedding(
|
||||||
1 / self.embedding_dim_inv_sqrt
|
num_embeddings=vocab_size,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
padding_idx=blank_id,
|
||||||
|
_weight=embedding_weight,
|
||||||
)
|
)
|
||||||
self.embedding.weight = nn.Parameter(embedding_weight)
|
self.embedding._fill_padding_idx_with_zero()
|
||||||
|
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
@ -97,7 +103,7 @@ class Decoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
embedding_out = self.embedding(y)
|
embedding_out = self.embedding(y)
|
||||||
if self.modify_embedding_layer:
|
if self.modify_embedding_layer:
|
||||||
embedding_out = self.embedding(y) * self.embedding_dim_inv_sqrt
|
embedding_out = self.embedding(y) * self.embedding_dim_sqrt
|
||||||
|
|
||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user