Fix bug with indicator

This commit is contained in:
Daniel Povey 2023-05-02 13:36:03 +08:00
parent c207c55e94
commit 75e9f1a34a

View File

@ -155,11 +155,11 @@ class PromptedTransducer(nn.Module):
text = self.text_embed(text) # now (T, N, C)
text_key_padding_mask = make_pad_mask(text_lens)
memory = self._add_style_indicator(memory, style_lens)
memory, text_lens = self.text_encoder(text, text_lens,
text_key_padding_mask)
memory = self._add_style_indicator(memory, style_lens)
memory_key_padding_mask = make_pad_mask(text_lens)
encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask,
@ -253,10 +253,9 @@ class PromptedTransducer(nn.Module):
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
"""
Adds to `memory` an indicator that is 0.1 for positions that correspond to
Adds to `memory` an indicator that is 1.0 for positions that correspond to
the `style prompt` and 0 elsewhere. The scale can be fixed because the
scale of the memory vector can adjust to compensate (within limits set
by the balancers)..
scale of the embedding vector can adjust to compensate.
Args:
memory: (memory_len, batch_size, embed_dim)
@ -267,8 +266,11 @@ class PromptedTransducer(nn.Module):
indicator = torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
indicator = indicator.to(memory.dtype).unsqueeze(-1)
indicator = indicator.to(memory.dtype)
return memory + indicator
extra_term = torch.zeros_like(memory)
extra_term[..., 0] += indicator
return memory + extra_term
Transducer = PromptedTransducer # for decoding