mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Minor fixes for torch.jit.script support (#1329)
This commit is contained in:
parent
902dc2364a
commit
92ef561ff7
@ -70,6 +70,10 @@ class Decoder(nn.Module):
|
||||
groups=embedding_dim,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
||||
# when inference with torch.jit.script and context_size == 1
|
||||
self.conv = nn.Identity()
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
|
@ -95,6 +95,10 @@ class Decoder(nn.Module):
|
||||
max_abs=1.0,
|
||||
prob=0.05,
|
||||
)
|
||||
else:
|
||||
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
||||
# when inference with torch.jit.script and context_size == 1
|
||||
self.conv = nn.Identity()
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
|
@ -74,6 +74,10 @@ class Decoder(nn.Module):
|
||||
groups=embedding_dim,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
||||
# when inference with torch.jit.script and context_size == 1
|
||||
self.conv = nn.Identity()
|
||||
self.output_linear = nn.Linear(embedding_dim, vocab_size)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
|
@ -71,6 +71,10 @@ class Decoder(nn.Module):
|
||||
groups=embedding_dim,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
||||
# when inference with torch.jit.script and context_size == 1
|
||||
self.conv = nn.Identity()
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
|
@ -17,7 +17,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scaling import Balancer
|
||||
|
||||
|
||||
@ -95,6 +94,10 @@ class Decoder(nn.Module):
|
||||
max_abs=1.0,
|
||||
prob=0.05,
|
||||
)
|
||||
else:
|
||||
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
||||
# when inference with torch.jit.script and context_size == 1
|
||||
self.conv = nn.Identity()
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user