from local

This commit is contained in:
dohe0342 2023-04-11 16:25:31 +09:00
parent dcfd8fb6a8
commit 8c7c3a0171
3 changed files with 3 additions and 2 deletions

View File

@ -76,8 +76,9 @@ class Transducer(nn.Module):
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
self.prompt = torch.randn((50, 512), requires_grad=True)
if prompt:
self.prompt = torch.randn((50, 512), requires_grad=True)
def forward(
self,