potential fix to k2SSL hubert_ce encoder forward function

This commit is contained in:
jianyou 2025-03-31 18:02:02 +08:00
parent db9fb8ad31
commit 90fb9207ba

View File

@ -429,7 +429,7 @@ class HubertModel(nn.Module):
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x = x.transpose(0, 1)
x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1))
x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1) if padding_mask else None)
x = x.transpose(0, 1)
if features_only: