mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
modify the implement of computing k
This commit is contained in:
parent
6e377c6874
commit
1551d7e826
@ -709,7 +709,8 @@ def greedy_search_batch(
|
||||
[h[-context_size - 1] for h in hyps[:batch_size]],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) == decoder_input[:, -1],
|
||||
)
|
||||
== decoder_input[:, -1],
|
||||
k + 1,
|
||||
torch.zeros(N, 1, device=device, dtype=torch.int64),
|
||||
)
|
||||
|
@ -87,7 +87,9 @@ class Transducer(nn.Module):
|
||||
Returns:
|
||||
Return a tensor of shape (N, U).
|
||||
"""
|
||||
y_shift = F.pad(y, (context_size, 0), mode="constant", value=self.decoder.blank_id)[:, :-context_size]
|
||||
y_shift = F.pad(
|
||||
y, (context_size, 0), mode="constant", value=self.decoder.blank_id
|
||||
)[:, :-context_size]
|
||||
mask = y_shift != y
|
||||
|
||||
T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device)
|
||||
@ -154,7 +156,7 @@ class Transducer(nn.Module):
|
||||
|
||||
# compute k
|
||||
k = self._compute_k(sos_y_padded, context_size=self.decoder.context_size)
|
||||
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded, k)
|
||||
|
||||
|
@ -235,7 +235,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
default="pruned_transducer_stateless9/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
Loading…
x
Reference in New Issue
Block a user