This commit is contained in:
Guo Liyong 2022-05-27 09:54:48 +08:00
parent d8f68abff8
commit 14a4d1d6f2

View File

@@ -134,7 +134,7 @@ class Transducer(nn.Module):
if self.training and codebook_indexes is not None: if self.training and codebook_indexes is not None:
assert hasattr(self, "codebook_loss_net") assert hasattr(self, "codebook_loss_net")
if codebook_indexes.shape[1] != middle_layer_output.shape[1]: if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
codebook_indexes = self.concat_sucessive_codebook_indexes( codebook_indexes = self.concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes middle_layer_output, codebook_indexes
) )
codebook_loss = self.codebook_loss_net( codebook_loss = self.codebook_loss_net(
@@ -221,7 +221,7 @@ class Transducer(nn.Module):
return (simple_loss, pruned_loss, codebook_loss) return (simple_loss, pruned_loss, codebook_loss)
@staticmethod @staticmethod
def concat_sucessive_codebook_indexes( def concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes middle_layer_output, codebook_indexes
): ):
# Output rate of hubert is 50 frames per second, # Output rate of hubert is 50 frames per second,