From 75e9f1a34a0213f61419359838eedb2b5aea4a06 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 2 May 2023 13:36:03 +0800 Subject: [PATCH] Fix bug with indicator --- .../ASR/pruned_transducer_stateless7/model.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 63d58fdbd..4993ef6c6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -155,11 +155,11 @@ class PromptedTransducer(nn.Module): text = self.text_embed(text) # now (T, N, C) text_key_padding_mask = make_pad_mask(text_lens) + memory = self._add_style_indicator(memory, style_lens) + memory, text_lens = self.text_encoder(text, text_lens, text_key_padding_mask) - memory = self._add_style_indicator(memory, style_lens) - memory_key_padding_mask = make_pad_mask(text_lens) encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask, @@ -253,10 +253,9 @@ class PromptedTransducer(nn.Module): def _add_style_indicator(self, memory: Tensor, style_lens: Tensor): """ - Adds to `memory` an indicator that is 0.1 for positions that correspond to + Adds to `memory` an indicator that is 1.0 for positions that correspond to the `style prompt` and 0 elsewhere. The scale can be fixed because the - scale of the memory vector can adjust to compensate (within limits set - by the balancers).. + scale of the embedding vector can adjust to compensate. Args: memory: (memory_len, batch_size, embed_dim) @@ -267,8 +266,11 @@ class PromptedTransducer(nn.Module): indicator = torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens - indicator = indicator.to(memory.dtype).unsqueeze(-1) + indicator = indicator.to(memory.dtype) - return memory + indicator + extra_term = torch.zeros_like(memory) + extra_term[..., 0] += indicator + + return memory + extra_term Transducer = PromptedTransducer # for decoding