Fit type hint

This commit is contained in:
yfy62 2023-04-26 17:45:00 +08:00
parent 23a9b66295
commit 5e7e1a350e

View File

@ -564,7 +564,7 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
pl: UniqLexicon,
batch: dict,
@ -948,7 +948,7 @@ def run(rank, world_size, args):
valid_cuts = gigaspeech.dev_cuts()
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
if 0 and not params.print_diagnostics:
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,