mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
add subformer model with style embeddings
This commit is contained in:
parent
16e8907805
commit
cad01bfcb6
@ -173,12 +173,15 @@ class PromptedTransducer(nn.Module):
|
||||
self.text_encoder_dim = 512
|
||||
self.freeze_text_encoder = freeze_text_encoder
|
||||
|
||||
self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lens: torch.Tensor,
|
||||
style_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
@ -243,7 +246,8 @@ class PromptedTransducer(nn.Module):
|
||||
if use_pre_text:
|
||||
memory, memory_key_padding_mask = self.encode_text(
|
||||
text,
|
||||
text_lens,
|
||||
text_lens=text_lens,
|
||||
style_lens=style_lens,
|
||||
)
|
||||
assert not memory.isnan().any(), memory
|
||||
else:
|
||||
@ -374,14 +378,15 @@ class PromptedTransducer(nn.Module):
|
||||
indicator = indicator.to(memory.dtype)
|
||||
|
||||
extra_term = torch.zeros_like(memory)
|
||||
extra_term[..., 0] += indicator
|
||||
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
|
||||
|
||||
return memory + extra_term
|
||||
|
||||
def encode_text(
|
||||
self,
|
||||
text: Tensor,
|
||||
text_lens: Tensor
|
||||
style_lens: Tensor,
|
||||
text_lens: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Get the embeddings of text
|
||||
|
||||
@ -400,6 +405,8 @@ class PromptedTransducer(nn.Module):
|
||||
memory, text_lens = self.text_encoder(
|
||||
text, text_lens, text_key_padding_mask
|
||||
)
|
||||
memory = self._add_style_indicator(memory, style_lens)
|
||||
|
||||
return memory, text_key_padding_mask
|
||||
|
||||
def encode_audio(
|
||||
|
Loading…
x
Reference in New Issue
Block a user