From 76e56fa28fa496ce9b3ecfcf69d23dd4b9fac4d9 Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Fri, 29 Apr 2022 10:54:11 +0800
Subject: [PATCH] check codebook index range

---
Reviewer notes (ignored by `git am`; not part of the commit message):

hubert_code_indices.py: drops the `.astype(np.int8)` cast on the extracted
codebook indices and asserts that every negative entry equals -100 and that
the largest entry is below 256. Presumably the cast was removed because int8
cannot hold 128..255; confirm with the author.

model.py: in the `self.training` branch, applies the same two range checks
(via torch) to `codebook_indices` before the `codebook_loss_net` check.

NOTE(review): -100 looks like a padding/ignore marker; TODO confirm.
NOTE(review): line breaks, hunk indentation and blank context lines in this
copy were re-flowed to match the hunk line counts; verify them against the
target files before applying.

 .../vq_pruned_transducer_stateless2/hubert_code_indices.py | 6 +++++-
 .../ASR/vq_pruned_transducer_stateless2/model.py           | 4 ++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_code_indices.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_code_indices.py
index 75242b331..2af976c86 100755
--- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_code_indices.py
+++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_code_indices.py
@@ -91,7 +91,11 @@ def compute_codeindices(
             )
 
             # [N, T, C]
-            codebook_indices = codebook_indices.to("cpu").numpy().astype(np.int8)
+            codebook_indices = codebook_indices.to("cpu").numpy()
+            assert np.all(
+                codebook_indices[np.where(codebook_indices < 0)] == -100
+            )
+            assert np.max(codebook_indices) < 256
 
             supervisions = batch["supervisions"]
             cut_list = supervisions["cut"]
diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/model.py
index 064bb7e62..d02efba54 100644
--- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/model.py
@@ -133,6 +133,10 @@ class Transducer(nn.Module):
         if self.training:
             # Do distillation.
             assert codebook_indices is not None
+            assert torch.all(
+                codebook_indices[torch.where(codebook_indices < 0)] == -100
+            )
+            assert torch.max(codebook_indices) < 256
             assert hasattr(self, "codebook_loss_net")
 
             # Output rate of hubert is 50 frames per second,