check codebook index range

This commit is contained in:
Guo Liyong 2022-04-29 10:54:11 +08:00
parent 5aaf981d46
commit 76e56fa28f
2 changed files with 9 additions and 1 deletions

View File

@ -91,7 +91,11 @@ def compute_codeindices(
)
# [N, T, C]
codebook_indices = codebook_indices.to("cpu").numpy().astype(np.int8)
codebook_indices = codebook_indices.to("cpu").numpy()
assert np.all(
codebook_indices[np.where(codebook_indices < 0)] == -100
)
assert np.max(codebook_indices) < 256
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]

View File

@ -133,6 +133,10 @@ class Transducer(nn.Module):
if self.training:
# Do distillation.
assert codebook_indices is not None
assert torch.all(
codebook_indices[torch.where(codebook_indices < 0)] == -100
)
assert torch.max(codebook_indices) < 256
assert hasattr(self, "codebook_loss_net")
# Output rate of hubert is 50 frames per second,