Update type hint

This commit is contained in:
yfy62 2023-04-26 17:39:07 +08:00
parent 9cbe54732a
commit c6b4159dcc

View File

@ -514,7 +514,7 @@ def load_checkpoint_if_available(
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
@ -642,7 +642,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
pl: UniqLexicon,
valid_dl: torch.utils.data.DataLoader,
@ -1012,7 +1012,7 @@ def run(rank, world_size, args):
def scan_pessimistic_batches_for_oom(
model: nn.Module,
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,