mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
check some files
This commit is contained in:
parent
187d59d59b
commit
fecceee216
@ -93,7 +93,9 @@ def fast_beam_search(
|
||||
)
|
||||
# fmt: on
|
||||
logits = model.joiner(
|
||||
current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False
|
||||
current_encoder_out.unsqueeze(2),
|
||||
decoder_out.unsqueeze(1),
|
||||
project_input=False,
|
||||
)
|
||||
logits = logits.squeeze(1).squeeze(1)
|
||||
log_probs = logits.log_softmax(dim=-1)
|
||||
@ -140,7 +142,6 @@ def greedy_search(
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
|
||||
T = encoder_out.size(1)
|
||||
t = 0
|
||||
hyp = [blank_id] * context_size
|
||||
@ -163,9 +164,9 @@ def greedy_search(
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
logits = model.joiner(current_encoder_out,
|
||||
decoder_out.unsqueeze(1),
|
||||
project_input=False)
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||
)
|
||||
# logits is (1, 1, 1, vocab_size)
|
||||
|
||||
y = logits.argmax().item()
|
||||
@ -228,8 +229,9 @@ def greedy_search_batch(
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1),
|
||||
project_input=False)
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||
)
|
||||
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||
@ -466,7 +468,6 @@ def modified_beam_search(
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||
|
||||
|
||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||
# as index, so we use `to(torch.int64)` below.
|
||||
current_encoder_out = torch.index_select(
|
||||
@ -720,7 +721,7 @@ def beam_search(
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out.unsqueeze(1),
|
||||
project_input=False
|
||||
project_input=False,
|
||||
)
|
||||
|
||||
# TODO(fangjun): Scale the blank posterior
|
||||
|
@ -17,9 +17,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from typing import Optional
|
||||
from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding
|
||||
from scaling import ScaledConv1d, ScaledEmbedding
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user