modify the implement of computing k

This commit is contained in:
yifanyang 2023-02-07 19:43:51 +08:00
parent 6e377c6874
commit 1551d7e826
3 changed files with 7 additions and 4 deletions

View File

@ -709,7 +709,8 @@ def greedy_search_batch(
[h[-context_size - 1] for h in hyps[:batch_size]], [h[-context_size - 1] for h in hyps[:batch_size]],
device=device, device=device,
dtype=torch.int64, dtype=torch.int64,
) == decoder_input[:, -1], )
== decoder_input[:, -1],
k + 1, k + 1,
torch.zeros(N, 1, device=device, dtype=torch.int64), torch.zeros(N, 1, device=device, dtype=torch.int64),
) )

View File

@ -87,7 +87,9 @@ class Transducer(nn.Module):
Returns: Returns:
Return a tensor of shape (N, U). Return a tensor of shape (N, U).
""" """
y_shift = F.pad(y, (context_size, 0), mode="constant", value=self.decoder.blank_id)[:, :-context_size] y_shift = F.pad(
y, (context_size, 0), mode="constant", value=self.decoder.blank_id
)[:, :-context_size]
mask = y_shift != y mask = y_shift != y
T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device) T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device)
@ -154,7 +156,7 @@ class Transducer(nn.Module):
# compute k # compute k
k = self._compute_k(sos_y_padded, context_size=self.decoder.context_size) k = self._compute_k(sos_y_padded, context_size=self.decoder.context_size)
# decoder_out: [B, S + 1, decoder_dim] # decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded, k) decoder_out = self.decoder(sos_y_padded, k)

View File

@ -235,7 +235,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless7/exp", default="pruned_transducer_stateless9/exp",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved