From 1551d7e82625eb8d319eef588e9c79b5f8c17bb6 Mon Sep 17 00:00:00 2001 From: yifanyang Date: Tue, 7 Feb 2023 19:43:51 +0800 Subject: [PATCH] modify the implement of computing k --- .../ASR/pruned_transducer_stateless9/beam_search.py | 3 ++- egs/librispeech/ASR/pruned_transducer_stateless9/model.py | 6 ++++-- egs/librispeech/ASR/pruned_transducer_stateless9/train.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py index 2397f0d41..155de536d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py @@ -709,7 +709,8 @@ def greedy_search_batch( [h[-context_size - 1] for h in hyps[:batch_size]], device=device, dtype=torch.int64, - ) == decoder_input[:, -1], + ) + == decoder_input[:, -1], k + 1, torch.zeros(N, 1, device=device, dtype=torch.int64), ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/model.py b/egs/librispeech/ASR/pruned_transducer_stateless9/model.py index 9e35692d6..29601807c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless9/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless9/model.py @@ -87,7 +87,9 @@ class Transducer(nn.Module): Returns: Return a tensor of shape (N, U). """ - y_shift = F.pad(y, (context_size, 0), mode="constant", value=self.decoder.blank_id)[:, :-context_size] + y_shift = F.pad( + y, (context_size, 0), mode="constant", value=self.decoder.blank_id + )[:, :-context_size] mask = y_shift != y T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device) @@ -154,7 +156,7 @@ class Transducer(nn.Module): # compute k k = self._compute_k(sos_y_padded, context_size=self.decoder.context_size) - + # decoder_out: [B, S + 1, decoder_dim] decoder_out = self.decoder(sos_y_padded, k) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/train.py b/egs/librispeech/ASR/pruned_transducer_stateless9/train.py index 86d79e48f..e85cca573 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless9/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless9/train.py @@ -235,7 +235,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless7/exp", + default="pruned_transducer_stateless9/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved