From cad01bfcb63f44e156241ce368177473d5373fe4 Mon Sep 17 00:00:00 2001 From: marcoyang1998 Date: Tue, 29 Aug 2023 16:04:51 +0800 Subject: [PATCH] add subformer model with style embeddings --- .../zipformer_prompt_asr/model_with_subformer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_subformer.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_subformer.py index 1fa150791..b67addbd8 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_subformer.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_subformer.py @@ -172,6 +172,8 @@ class PromptedTransducer(nn.Module): self.text_encoder_dim = 512 self.freeze_text_encoder = freeze_text_encoder + + self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5)) def forward( self, @@ -179,6 +181,7 @@ class PromptedTransducer(nn.Module): x_lens: torch.Tensor, text: torch.Tensor, text_lens: torch.Tensor, + style_lens: torch.Tensor, y: k2.RaggedTensor, prune_range: int = 5, am_scale: float = 0.0, @@ -243,7 +246,8 @@ class PromptedTransducer(nn.Module): if use_pre_text: memory, memory_key_padding_mask = self.encode_text( text, - text_lens, + text_lens=text_lens, + style_lens=style_lens, ) assert not memory.isnan().any(), memory else: @@ -374,14 +378,15 @@ class PromptedTransducer(nn.Module): indicator = indicator.to(memory.dtype) extra_term = torch.zeros_like(memory) - extra_term[..., 0] += indicator + extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim) return memory + extra_term def encode_text( self, text: Tensor, - text_lens: Tensor + style_lens: Tensor, + text_lens: Tensor, ) -> Tuple[Tensor, Tensor]: """Get the embeddings of text @@ -400,6 +405,8 @@ class PromptedTransducer(nn.Module): memory, text_lens = self.text_encoder( text, text_lens, text_key_padding_mask ) + memory = self._add_style_indicator(memory, style_lens) + return memory, text_key_padding_mask def encode_audio(