modify the implement of computing k

This commit is contained in:
yifanyang 2023-02-07 19:43:51 +08:00
parent 6e377c6874
commit 1551d7e826
3 changed files with 7 additions and 4 deletions

View File

@ -709,7 +709,8 @@ def greedy_search_batch(
[h[-context_size - 1] for h in hyps[:batch_size]],
device=device,
dtype=torch.int64,
) == decoder_input[:, -1],
)
== decoder_input[:, -1],
k + 1,
torch.zeros(N, 1, device=device, dtype=torch.int64),
)

View File

@ -87,7 +87,9 @@ class Transducer(nn.Module):
Returns:
Return a tensor of shape (N, U).
"""
y_shift = F.pad(y, (context_size, 0), mode="constant", value=self.decoder.blank_id)[:, :-context_size]
y_shift = F.pad(
y, (context_size, 0), mode="constant", value=self.decoder.blank_id
)[:, :-context_size]
mask = y_shift != y
T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device)
@ -154,7 +156,7 @@ class Transducer(nn.Module):
# compute k
k = self._compute_k(sos_y_padded, context_size=self.decoder.context_size)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded, k)

View File

@ -235,7 +235,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
default="pruned_transducer_stateless9/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved