Fix bug with indicator

This commit is contained in:
Daniel Povey 2023-05-02 13:36:03 +08:00
parent c207c55e94
commit 75e9f1a34a

View File

@ -155,11 +155,11 @@ class PromptedTransducer(nn.Module):
text = self.text_embed(text) # now (T, N, C) text = self.text_embed(text) # now (T, N, C)
text_key_padding_mask = make_pad_mask(text_lens) text_key_padding_mask = make_pad_mask(text_lens)
memory = self._add_style_indicator(memory, style_lens)
memory, text_lens = self.text_encoder(text, text_lens, memory, text_lens = self.text_encoder(text, text_lens,
text_key_padding_mask) text_key_padding_mask)
memory = self._add_style_indicator(memory, style_lens)
memory_key_padding_mask = make_pad_mask(text_lens) memory_key_padding_mask = make_pad_mask(text_lens)
encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask, encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask,
@ -253,10 +253,9 @@ class PromptedTransducer(nn.Module):
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor): def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
""" """
Adds to `memory` an indicator that is 0.1 for positions that correspond to Adds to `memory` an indicator that is 1.0 for positions that correspond to
the `style prompt` and 0 elsewhere. The scale can be fixed because the the `style prompt` and 0 elsewhere. The scale can be fixed because the
scale of the memory vector can adjust to compensate (within limits set scale of the embedding vector can adjust to compensate.
by the balancers)..
Args: Args:
memory: (memory_len, batch_size, embed_dim) memory: (memory_len, batch_size, embed_dim)
@ -267,8 +266,11 @@ class PromptedTransducer(nn.Module):
indicator = torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens indicator = torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
indicator = indicator.to(memory.dtype).unsqueeze(-1) indicator = indicator.to(memory.dtype)
return memory + indicator extra_term = torch.zeros_like(memory)
extra_term[..., 0] += indicator
return memory + extra_term
Transducer = PromptedTransducer # for decoding Transducer = PromptedTransducer # for decoding