minor updates

JinZr 2023-07-27 17:52:49 +08:00
parent d70f6e21f2
commit 634931cb61
3 changed files with 91 additions and 7 deletions

View File

@@ -779,6 +779,78 @@ def greedy_search_batch(
)
def deprecated_greedy_search_batch(
    model: nn.Module, encoder_out: torch.Tensor
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
    Returns:
      Return a list-of-list of token IDs containing the decoded results.
      len(ans) equals to encoder_out.size(0).
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    device = next(model.parameters()).device

    batch_size = encoder_out.size(0)
    T = encoder_out.size(1)

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size

    hyps = [[blank_id] * context_size for _ in range(batch_size)]

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    ) # (batch_size, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)
    decoder_out = model.joiner.decoder_proj(decoder_out)
    encoder_out = model.joiner.encoder_proj(encoder_out)

    # decoder_out: (batch_size, 1, decoder_out_dim)
    for t in range(T):
        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
        logits = model.joiner(
            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
        )
        print(current_encoder_out)
        print(decoder_out.unsqueeze(1))
        print(logits)
        exit()
        # logits'shape (batch_size, 1, 1, vocab_size)
        logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v not in (blank_id, unk_id):
                hyps[i].append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.decoder(decoder_input, need_pad=False)
            decoder_out = model.joiner.decoder_proj(decoder_out)

    ans = [h[context_size:] for h in hyps]
    return ans
def deprecated_greedy_search_batch_for_cross_attn(
    model: nn.Module,
    encoder_out: torch.Tensor,
@@ -822,23 +894,29 @@ def deprecated_greedy_search_batch_for_cross_attn(
    # decoder_out: (batch_size, 1, decoder_out_dim)
    for t in range(T):
-        # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
-        current_encoder_out = model.joiner.label_level_am_attention(
+        attn_encoder_out = model.joiner.label_level_am_attention(
            encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(2),
-            # encoder_out_lens,
-            None,
+            encoder_out_lens,
        )
+        # print(encoder_out[:, t : t + 1, :].unsqueeze(2))
+        # current_encoder_out = torch.zeros_like(current_encoder_out)
        logits = model.joiner(
            current_encoder_out,
            decoder_out.unsqueeze(1),
+            torch.zeros_like(current_encoder_out),
            None,
            apply_attn=False,
            project_input=False,
        )
        # logits'shape (batch_size, 1, 1, vocab_size)
        logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
+        # print(current_encoder_out)
+        # print(decoder_out)
+        # print(logits)
+        # # exit()
        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
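Note: the deprecated_greedy_search_batch function added above follows the usual frame-synchronous greedy loop of a transducer decoder: one argmax over the joiner output per frame, and the decoder is re-run only when at least one stream in the batch emits a non-blank (and non-unk) token. Below is a minimal, self-contained sketch of that loop; the toy decoder and joiner are illustrative stand-ins (names and dimensions are assumptions), not the repo's modules.

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Toy stateless decoder: averages the embeddings of the last context_size tokens."""

    def __init__(self, vocab_size: int, dim: int, context_size: int = 2):
        super().__init__()
        self.blank_id = 0
        self.context_size = context_size
        self.embed = nn.Embedding(vocab_size, dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, context_size) -> (N, 1, dim)
        return self.embed(y).mean(dim=1, keepdim=True)


def toy_greedy_search_batch(decoder, joiner, encoder_out, blank_id=0):
    """Frame-synchronous greedy search, at most one symbol per frame."""
    N, T, _ = encoder_out.shape
    context_size = decoder.context_size
    hyps = [[blank_id] * context_size for _ in range(N)]
    decoder_input = torch.tensor(hyps, dtype=torch.int64)
    decoder_out = decoder(decoder_input)  # (N, 1, dim)
    for t in range(T):
        cur = encoder_out[:, t : t + 1, :]  # (N, 1, dim)
        logits = joiner(cur, decoder_out).squeeze(1)  # (N, vocab_size)
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # Re-run the decoder only when some stream emitted a symbol.
            decoder_input = torch.tensor(
                [h[-context_size:] for h in hyps], dtype=torch.int64
            )
            decoder_out = decoder(decoder_input)
    return [h[context_size:] for h in hyps]


vocab_size, dim = 8, 16
decoder = ToyDecoder(vocab_size, dim)
W = torch.randn(vocab_size, dim)


def joiner(enc: torch.Tensor, dec: torch.Tensor) -> torch.Tensor:
    # Toy joiner: add, tanh, project to the vocabulary.
    return torch.tanh(enc + dec) @ W.t()


print(toy_greedy_search_batch(decoder, joiner, torch.randn(2, 5, dim)))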

View File

@@ -39,6 +39,7 @@ class Joiner(nn.Module):
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
+        attn_encoder_out: torch.Tensor,
        lengths: torch.Tensor,
        apply_attn: bool = True,
        project_input: bool = True,
@@ -64,14 +65,14 @@ class Joiner(nn.Module):
        )

        if apply_attn and lengths is not None:
-            encoder_out = self.label_level_am_attention(
+            attn_encoder_out = self.label_level_am_attention(
                encoder_out, decoder_out, lengths
            )

        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
-            logit = encoder_out + decoder_out
+            logit = encoder_out + decoder_out + attn_encoder_out
        logit = self.output_linear(torch.tanh(logit))
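Note: the effect of this change is that Joiner.forward now accepts a pre-computed attn_encoder_out and, when the inputs are already projected (project_input=False), simply adds it to the encoder and decoder terms before the tanh and output projection. A minimal sketch of that combination, assuming all three inputs already have the joiner dimension (module and dimension names below are illustrative, not the repo's exact code):

import torch
import torch.nn as nn


class SketchJoiner(nn.Module):
    """Minimal joiner: sum already-projected inputs, then tanh + output projection."""

    def __init__(self, joiner_dim: int, vocab_size: int):
        super().__init__()
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,       # (..., joiner_dim), already projected
        decoder_out: torch.Tensor,       # (..., joiner_dim), already projected
        attn_encoder_out: torch.Tensor,  # (..., joiner_dim), label-level AM attention
    ) -> torch.Tensor:
        logit = encoder_out + decoder_out + attn_encoder_out
        return self.output_linear(torch.tanh(logit))


joiner = SketchJoiner(joiner_dim=16, vocab_size=8)
x = torch.randn(2, 4, 3, 16)
print(joiner(x, x, torch.zeros_like(x)).shape)  # torch.Size([2, 4, 3, 8])

Because the three terms are summed elementwise, the same forward call works for the full (N, T, U, D) training case and for the per-frame (N, 1, 1, D) tensors used in greedy search.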

View File

@@ -265,7 +265,12 @@ class AsrModel(nn.Module):
        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(
-            am_pruned, lm_pruned, encoder_out_lens, apply_attn=True, project_input=False
+            am_pruned,
+            lm_pruned,
+            None,
+            encoder_out_lens,
+            apply_attn=True,
+            project_input=False,
        )

        with torch.cuda.amp.autocast(enabled=False):
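Note: with the new attn_encoder_out parameter, the pruned-RNNT call site above passes None in that slot and relies on apply_attn=True plus encoder_out_lens so the attention is computed inside the joiner. A small self-contained sketch of the call shape, using a hypothetical stand-in function with the same parameter order (names, shapes, and the keyword style are assumptions, not the repo's exact code):

import torch


def joiner_forward(encoder_out, decoder_out, attn_encoder_out, lengths,
                   apply_attn=True, project_input=True):
    # Stand-in with the same parameter order as the updated Joiner.forward;
    # apply_attn and project_input are accepted but unused here, since this
    # sketch only adds the tensors so the call shape can be checked.
    extra = attn_encoder_out if attn_encoder_out is not None else 0
    return torch.tanh(encoder_out + decoder_out + extra)


am_pruned = torch.randn(2, 10, 5, 16)  # hypothetical (N, T, prune_range, joiner_dim)
lm_pruned = torch.randn(2, 10, 5, 16)
encoder_out_lens = torch.tensor([10, 7])

# Passing the trailing arguments by keyword keeps the None meant for
# attn_encoder_out from being mistaken for the lengths argument.
logits = joiner_forward(
    am_pruned,
    lm_pruned,
    attn_encoder_out=None,
    lengths=encoder_out_lens,
    apply_attn=True,
    project_input=False,
)
print(logits.shape)  # torch.Size([2, 10, 5, 16])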