add subformer model with style embeddings

This commit is contained in:
marcoyang1998 2023-08-29 16:04:51 +08:00
parent 16e8907805
commit cad01bfcb6

View File

@ -173,12 +173,15 @@ class PromptedTransducer(nn.Module):
self.text_encoder_dim = 512
self.freeze_text_encoder = freeze_text_encoder
self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
text: torch.Tensor,
text_lens: torch.Tensor,
style_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
@ -243,7 +246,8 @@ class PromptedTransducer(nn.Module):
if use_pre_text:
memory, memory_key_padding_mask = self.encode_text(
text,
text_lens,
text_lens=text_lens,
style_lens=style_lens,
)
assert not memory.isnan().any(), memory
else:
@ -374,14 +378,15 @@ class PromptedTransducer(nn.Module):
indicator = indicator.to(memory.dtype)
extra_term = torch.zeros_like(memory)
extra_term[..., 0] += indicator
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
return memory + extra_term
def encode_text(
self,
text: Tensor,
text_lens: Tensor
style_lens: Tensor,
text_lens: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Get the embeddings of text
@ -400,6 +405,8 @@ class PromptedTransducer(nn.Module):
memory, text_lens = self.text_encoder(
text, text_lens, text_key_padding_mask
)
memory = self._add_style_indicator(memory, style_lens)
return memory, text_key_padding_mask
def encode_audio(