Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 16:14:17 +00:00)

commit fecceee216 (parent 187d59d59b)
Commit message: check some files
@@ -93,7 +93,9 @@ def fast_beam_search(
         )
         # fmt: on
         logits = model.joiner(
-            current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False
+            current_encoder_out.unsqueeze(2),
+            decoder_out.unsqueeze(1),
+            project_input=False,
         )
         logits = logits.squeeze(1).squeeze(1)
         log_probs = logits.log_softmax(dim=-1)
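Note: `project_input=False` tells the joiner that the caller has already applied `encoder_proj`/`decoder_proj` (both projections appear directly elsewhere in this diff). A minimal sketch of such a joiner, assuming an additive combination with a tanh nonlinearity (the dimensions and the nonlinearity are assumptions, not taken from this commit):

    import torch
    import torch.nn as nn

    class Joiner(nn.Module):
        def __init__(self, encoder_dim, decoder_dim, joiner_dim, vocab_size):
            super().__init__()
            self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
            self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
            self.output_linear = nn.Linear(joiner_dim, vocab_size)

        def forward(self, encoder_out, decoder_out, project_input=True):
            # With project_input=False the caller already ran encoder_proj/
            # decoder_proj, so only combine and map to the vocabulary.
            if project_input:
                encoder_out = self.encoder_proj(encoder_out)
                decoder_out = self.decoder_proj(decoder_out)
            return self.output_linear(torch.tanh(encoder_out + decoder_out))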
@@ -140,7 +142,6 @@ def greedy_search(

     encoder_out = model.joiner.encoder_proj(encoder_out)

-
     T = encoder_out.size(1)
     t = 0
     hyp = [blank_id] * context_size
@@ -163,9 +164,9 @@ def greedy_search(
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
         # fmt: on
-        logits = model.joiner(current_encoder_out,
-                              decoder_out.unsqueeze(1),
-                              project_input=False)
+        logits = model.joiner(
+            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+        )
         # logits is (1, 1, 1, vocab_size)

         y = logits.argmax().item()
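For orientation, the call above sits inside the per-frame greedy loop. A condensed sketch of that loop, using the names visible in this diff (`t`, `T`, `hyp`, `blank_id`, `context_size`); the decoder-update step and the one-symbol-then-advance frame handling are simplifying assumptions:

    while t < T:
        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)
        logits = model.joiner(
            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
        )
        y = logits.argmax().item()
        if y != blank_id:
            hyp.append(y)
            # Feed the last context_size symbols back through the decoder
            # (device is assumed to be defined by the caller).
            decoder_input = torch.tensor([hyp[-context_size:]], device=device)
            decoder_out = model.joiner.decoder_proj(model.decoder(decoder_input))
        else:
            t += 1  # blank: move on to the next encoder frame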
@@ -228,8 +229,9 @@ def greedy_search_batch(
     for t in range(T):
         current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
-        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1),
-                              project_input=False)
+        logits = model.joiner(
+            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+        )
         # logits'shape (batch_size, 1, 1, vocab_size)

         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
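The two `squeeze(1)` calls drop the singleton time and decoder axes left over from the 4-D joiner output. A quick shape check, with hypothetical sizes:

    import torch

    logits = torch.randn(4, 1, 1, 500)    # (batch_size, 1, 1, vocab_size)
    logits = logits.squeeze(1).squeeze(1)
    assert logits.shape == (4, 500)       # (batch_size, vocab_size)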
@@ -466,7 +468,6 @@ def modified_beam_search(
         decoder_out = model.joiner.decoder_proj(decoder_out)
         # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

-
         # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
         # as index, so we use `to(torch.int64)` below.
         current_encoder_out = torch.index_select(
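The note about torch 1.7.1 concerns `torch.index_select`, which on those releases only accepts an int64 index tensor; hence the explicit cast. An illustration with hypothetical shapes:

    import torch

    encoder_frames = torch.randn(3, 1, 1, 512)   # e.g. (num_hyps, 1, 1, joiner_dim)
    hyp_indexes = torch.tensor([0, 0, 2], dtype=torch.int32)

    # Older torch releases reject a non-int64 index here, so cast first.
    current_encoder_out = torch.index_select(
        encoder_frames, dim=0, index=hyp_indexes.to(torch.int64)
    )
    assert current_encoder_out.shape == (3, 1, 1, 512)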
@@ -720,7 +721,7 @@ def beam_search(
             logits = model.joiner(
                 current_encoder_out,
                 decoder_out.unsqueeze(1),
-                project_input=False
+                project_input=False,
             )

             # TODO(fangjun): Scale the blank posterior
@@ -17,9 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
-from typing import Optional
-from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding
+from scaling import ScaledConv1d, ScaledEmbedding


 class Decoder(nn.Module):
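The import cleanup keeps only the two scaling modules this file still uses: `ScaledEmbedding` and `ScaledConv1d` (drop-in counterparts of `nn.Embedding` and `nn.Conv1d` from icefall's scaling.py); `Tensor`, `Optional`, and `ScaledLinear` were unused. A rough sketch of how a stateless transducer decoder might use them, embedding the previous `context_size` symbols and mixing them with a grouped 1-D convolution; the constructor arguments and padding logic are illustrative, not copied from the commit:

    class Decoder(nn.Module):
        def __init__(self, vocab_size, decoder_dim, blank_id, context_size):
            super().__init__()
            self.embedding = ScaledEmbedding(
                vocab_size, decoder_dim, padding_idx=blank_id
            )
            self.context_size = context_size
            if context_size > 1:
                self.conv = ScaledConv1d(
                    decoder_dim,
                    decoder_dim,
                    kernel_size=context_size,
                    groups=decoder_dim,
                )

        def forward(self, y, need_pad=True):
            embedding_out = self.embedding(y)                   # (N, U, decoder_dim)
            if self.context_size > 1:
                embedding_out = embedding_out.permute(0, 2, 1)  # (N, C, U) for conv
                if need_pad:
                    # Left-pad so each output frame sees context_size symbols.
                    embedding_out = F.pad(
                        embedding_out, pad=(self.context_size - 1, 0)
                    )
                embedding_out = self.conv(embedding_out).permute(0, 2, 1)
            return F.relu(embedding_out)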