Fit type hint

This commit is contained in:
yfy62 2023-04-26 17:45:00 +08:00
parent 23a9b66295
commit 5e7e1a350e

View File

@ -564,7 +564,7 @@ def save_checkpoint(
def compute_loss( def compute_loss(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
pl: UniqLexicon, pl: UniqLexicon,
batch: dict, batch: dict,
@ -948,7 +948,7 @@ def run(rank, world_size, args):
valid_cuts = gigaspeech.dev_cuts() valid_cuts = gigaspeech.dev_cuts()
valid_dl = gigaspeech.valid_dataloaders(valid_cuts) valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
if 0 and not params.print_diagnostics: if not params.print_diagnostics:
scan_pessimistic_batches_for_oom( scan_pessimistic_batches_for_oom(
model=model, model=model,
train_dl=train_dl, train_dl=train_dl,