Update dit.py

This commit is contained in:
zr_jin 2025-01-27 15:57:05 +08:00 committed by GitHub
parent 59cba78889
commit d679567814
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -92,11 +92,11 @@ class InputEmbedding(nn.Module):
def forward(
self,
x: float["b n d"],
cond: float["b n d"],
text_embed: float["b n d"],
x: float["b n d"], # noqa: F722
cond: float["b n d"], # noqa: F722
text_embed: float["b n d"], # noqa: F722
drop_audio_cond=False,
): # noqa: F722
):
if drop_audio_cond: # cfg for cond audio
cond = torch.zeros_like(cond)