check some files

This commit is contained in:
luomingshuang 2022-04-11 20:38:10 +08:00
parent 187d59d59b
commit fecceee216
2 changed files with 11 additions and 12 deletions

View File

@ -93,7 +93,9 @@ def fast_beam_search(
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
@ -140,7 +142,6 @@ def greedy_search(
encoder_out = model.joiner.encoder_proj(encoder_out)
T = encoder_out.size(1)
t = 0
hyp = [blank_id] * context_size
@ -163,9 +164,9 @@ def greedy_search(
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
logits = model.joiner(current_encoder_out,
decoder_out.unsqueeze(1),
project_input=False)
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
@ -228,8 +229,9 @@ def greedy_search_batch(
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1),
project_input=False)
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits'shape (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
@ -466,7 +468,6 @@ def modified_beam_search(
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
@ -720,7 +721,7 @@ def beam_search(
logits = model.joiner(
current_encoder_out,
decoder_out.unsqueeze(1),
project_input=False
project_input=False,
)
# TODO(fangjun): Scale the blank posterior

View File

@ -17,9 +17,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional
from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding
from scaling import ScaledConv1d, ScaledEmbedding
class Decoder(nn.Module):