mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
do some changes
This commit is contained in:
parent
f1b7ab0226
commit
18b8c670dd
@ -38,7 +38,7 @@ class Decoder(nn.Module):
|
||||
embedding_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
modify_embedding_layer: bool,
|
||||
modify_embedding_layer: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -52,7 +52,7 @@ class Decoder(nn.Module):
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
modify_embedding_layer:
|
||||
Do or not do modify the decoder embedding layer.
|
||||
Decide to modify the decoder embedding layer or not.
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(
|
||||
@ -61,13 +61,12 @@ class Decoder(nn.Module):
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.modify_embedding_layer = modify_embedding_layer
|
||||
self.embedding_dim_inv_sqrt = embedding_dim ** 0.5
|
||||
if self.modify_embedding_layer:
|
||||
self.embedding.weight = nn.Parameter(
|
||||
(
|
||||
torch.randn(vocab_size, embedding_dim)
|
||||
* (embedding_dim ** -0.5)
|
||||
).detach()
|
||||
embedding_weight = torch.randn(vocab_size, embedding_dim) * (
|
||||
1 / self.embedding_dim_inv_sqrt
|
||||
)
|
||||
self.embedding.weight = nn.Parameter(embedding_weight)
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.blank_id = blank_id
|
||||
@ -98,7 +97,7 @@ class Decoder(nn.Module):
|
||||
"""
|
||||
embedding_out = self.embedding(y)
|
||||
if self.modify_embedding_layer:
|
||||
embedding_out = self.embedding(y) * (self.embedding_dim ** 0.5)
|
||||
embedding_out = self.embedding(y) * self.embedding_dim_inv_sqrt
|
||||
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
|
@ -182,7 +182,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--modify-embedding",
|
||||
type=bool,
|
||||
default=True,
|
||||
default=False,
|
||||
help="When True, we modify the decoder embedding layer."
|
||||
"When False, we don't modify the decoder embedding layer.",
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user