This commit is contained in:
Guo Liyong 2022-05-27 09:54:48 +08:00
parent d8f68abff8
commit 14a4d1d6f2

View File

@ -134,7 +134,7 @@ class Transducer(nn.Module):
if self.training and codebook_indexes is not None:
assert hasattr(self, "codebook_loss_net")
if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
codebook_indexes = self.concat_sucessive_codebook_indexes(
codebook_indexes = self.concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
)
codebook_loss = self.codebook_loss_net(
@ -221,7 +221,7 @@ class Transducer(nn.Module):
return (simple_loss, pruned_loss, codebook_loss)
@staticmethod
def concat_sucessive_codebook_indexes(
def concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
):
# Output rate of hubert is 50 frames per second,