mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix bug with indicator
This commit is contained in:
parent
c207c55e94
commit
75e9f1a34a
@ -155,11 +155,11 @@ class PromptedTransducer(nn.Module):
|
||||
text = self.text_embed(text) # now (T, N, C)
|
||||
text_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
memory = self._add_style_indicator(memory, style_lens)
|
||||
|
||||
memory, text_lens = self.text_encoder(text, text_lens,
|
||||
text_key_padding_mask)
|
||||
|
||||
memory = self._add_style_indicator(memory, style_lens)
|
||||
|
||||
memory_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask,
|
||||
@ -253,10 +253,9 @@ class PromptedTransducer(nn.Module):
|
||||
|
||||
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
|
||||
"""
|
||||
Adds to `memory` an indicator that is 0.1 for positions that correspond to
|
||||
Adds to `memory` an indicator that is 1.0 for positions that correspond to
|
||||
the `style prompt` and 0 elsewhere. The scale can be fixed because the
|
||||
scale of the memory vector can adjust to compensate (within limits set
|
||||
by the balancers)..
|
||||
scale of the embedding vector can adjust to compensate.
|
||||
|
||||
Args:
|
||||
memory: (memory_len, batch_size, embed_dim)
|
||||
@ -267,8 +266,11 @@ class PromptedTransducer(nn.Module):
|
||||
|
||||
|
||||
indicator = torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
|
||||
indicator = indicator.to(memory.dtype).unsqueeze(-1)
|
||||
indicator = indicator.to(memory.dtype)
|
||||
|
||||
return memory + indicator
|
||||
extra_term = torch.zeros_like(memory)
|
||||
extra_term[..., 0] += indicator
|
||||
|
||||
return memory + extra_term
|
||||
|
||||
Transducer = PromptedTransducer # for decoding
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user