mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
check codebook index range
This commit is contained in:
parent
5aaf981d46
commit
76e56fa28f
@ -91,7 +91,11 @@ def compute_codeindices(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# [N, T, C]
|
# [N, T, C]
|
||||||
codebook_indices = codebook_indices.to("cpu").numpy().astype(np.int8)
|
codebook_indices = codebook_indices.to("cpu").numpy()
|
||||||
|
assert np.all(
|
||||||
|
codebook_indices[np.where(codebook_indices < 0)] == -100
|
||||||
|
)
|
||||||
|
assert np.max(codebook_indices) < 256
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
cut_list = supervisions["cut"]
|
cut_list = supervisions["cut"]
|
||||||
|
@ -133,6 +133,10 @@ class Transducer(nn.Module):
|
|||||||
if self.training:
|
if self.training:
|
||||||
# Do distillation.
|
# Do distillation.
|
||||||
assert codebook_indices is not None
|
assert codebook_indices is not None
|
||||||
|
assert torch.all(
|
||||||
|
codebook_indices[torch.where(codebook_indices < 0)] == -100
|
||||||
|
)
|
||||||
|
assert torch.max(codebook_indices) < 256
|
||||||
assert hasattr(self, "codebook_loss_net")
|
assert hasattr(self, "codebook_loss_net")
|
||||||
|
|
||||||
# Output rate of hubert is 50 frames per second,
|
# Output rate of hubert is 50 frames per second,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user