mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
fix typo
This commit is contained in:
parent
d8f68abff8
commit
14a4d1d6f2
@ -134,7 +134,7 @@ class Transducer(nn.Module):
|
|||||||
if self.training and codebook_indexes is not None:
|
if self.training and codebook_indexes is not None:
|
||||||
assert hasattr(self, "codebook_loss_net")
|
assert hasattr(self, "codebook_loss_net")
|
||||||
if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
|
if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
|
||||||
codebook_indexes = self.concat_sucessive_codebook_indexes(
|
codebook_indexes = self.concat_successive_codebook_indexes(
|
||||||
middle_layer_output, codebook_indexes
|
middle_layer_output, codebook_indexes
|
||||||
)
|
)
|
||||||
codebook_loss = self.codebook_loss_net(
|
codebook_loss = self.codebook_loss_net(
|
||||||
@ -221,7 +221,7 @@ class Transducer(nn.Module):
|
|||||||
return (simple_loss, pruned_loss, codebook_loss)
|
return (simple_loss, pruned_loss, codebook_loss)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def concat_sucessive_codebook_indexes(
|
def concat_successive_codebook_indexes(
|
||||||
middle_layer_output, codebook_indexes
|
middle_layer_output, codebook_indexes
|
||||||
):
|
):
|
||||||
# Output rate of hubert is 50 frames per second,
|
# Output rate of hubert is 50 frames per second,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user