Minor fixes for torch.jit.script support (#1329)

This commit is contained in:
zr_jin 2023-10-24 01:10:50 +08:00 committed by GitHub
parent 902dc2364a
commit 92ef561ff7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 1 deletions

View File

@ -70,6 +70,10 @@ class Decoder(nn.Module):
groups=embedding_dim,
bias=False,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

@ -95,6 +95,10 @@ class Decoder(nn.Module):
max_abs=1.0,
prob=0.05,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

@ -74,6 +74,10 @@ class Decoder(nn.Module):
groups=embedding_dim,
bias=False,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
self.output_linear = nn.Linear(embedding_dim, vocab_size)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:

View File

@ -71,6 +71,10 @@ class Decoder(nn.Module):
groups=embedding_dim,
bias=False,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

@ -17,7 +17,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaling import Balancer
@ -95,6 +94,10 @@ class Decoder(nn.Module):
max_abs=1.0,
prob=0.05,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""