mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
add subformer model with style embeddings
This commit is contained in:
parent
16e8907805
commit
cad01bfcb6
@ -172,6 +172,8 @@ class PromptedTransducer(nn.Module):
|
|||||||
|
|
||||||
self.text_encoder_dim = 512
|
self.text_encoder_dim = 512
|
||||||
self.freeze_text_encoder = freeze_text_encoder
|
self.freeze_text_encoder = freeze_text_encoder
|
||||||
|
|
||||||
|
self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -179,6 +181,7 @@ class PromptedTransducer(nn.Module):
|
|||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
text: torch.Tensor,
|
text: torch.Tensor,
|
||||||
text_lens: torch.Tensor,
|
text_lens: torch.Tensor,
|
||||||
|
style_lens: torch.Tensor,
|
||||||
y: k2.RaggedTensor,
|
y: k2.RaggedTensor,
|
||||||
prune_range: int = 5,
|
prune_range: int = 5,
|
||||||
am_scale: float = 0.0,
|
am_scale: float = 0.0,
|
||||||
@ -243,7 +246,8 @@ class PromptedTransducer(nn.Module):
|
|||||||
if use_pre_text:
|
if use_pre_text:
|
||||||
memory, memory_key_padding_mask = self.encode_text(
|
memory, memory_key_padding_mask = self.encode_text(
|
||||||
text,
|
text,
|
||||||
text_lens,
|
text_lens=text_lens,
|
||||||
|
style_lens=style_lens,
|
||||||
)
|
)
|
||||||
assert not memory.isnan().any(), memory
|
assert not memory.isnan().any(), memory
|
||||||
else:
|
else:
|
||||||
@ -374,14 +378,15 @@ class PromptedTransducer(nn.Module):
|
|||||||
indicator = indicator.to(memory.dtype)
|
indicator = indicator.to(memory.dtype)
|
||||||
|
|
||||||
extra_term = torch.zeros_like(memory)
|
extra_term = torch.zeros_like(memory)
|
||||||
extra_term[..., 0] += indicator
|
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
|
||||||
|
|
||||||
return memory + extra_term
|
return memory + extra_term
|
||||||
|
|
||||||
def encode_text(
|
def encode_text(
|
||||||
self,
|
self,
|
||||||
text: Tensor,
|
text: Tensor,
|
||||||
text_lens: Tensor
|
style_lens: Tensor,
|
||||||
|
text_lens: Tensor,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Get the embeddings of text
|
"""Get the embeddings of text
|
||||||
|
|
||||||
@ -400,6 +405,8 @@ class PromptedTransducer(nn.Module):
|
|||||||
memory, text_lens = self.text_encoder(
|
memory, text_lens = self.text_encoder(
|
||||||
text, text_lens, text_key_padding_mask
|
text, text_lens, text_key_padding_mask
|
||||||
)
|
)
|
||||||
|
memory = self._add_style_indicator(memory, style_lens)
|
||||||
|
|
||||||
return memory, text_key_padding_mask
|
return memory, text_key_padding_mask
|
||||||
|
|
||||||
def encode_audio(
|
def encode_audio(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user