diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py b/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py index 9059757fe..b04803470 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py @@ -92,11 +92,11 @@ class InputEmbedding(nn.Module): def forward( self, - x: float["b n d"], - cond: float["b n d"], - text_embed: float["b n d"], + x: float["b n d"], # noqa: F722 + cond: float["b n d"], # noqa: F722 + text_embed: float["b n d"], # noqa: F722 drop_audio_cond=False, - ): # noqa: F722 + ): if drop_audio_cond: # cfg for cond audio cond = torch.zeros_like(cond)