From c6b4159dccb015a246a1be79cfce8748365e0e54 Mon Sep 17 00:00:00 2001 From: yfy62 Date: Wed, 26 Apr 2023 17:39:07 +0800 Subject: [PATCH] Update type hint --- egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 2ce785ef8..09a3b6a1b 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -514,7 +514,7 @@ def load_checkpoint_if_available( def save_checkpoint( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, @@ -642,7 +642,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, pl: UniqLexicon, valid_dl: torch.utils.data.DataLoader, @@ -1012,7 +1012,7 @@ def run(rank, world_size, args): def scan_pessimistic_batches_for_oom( - model: nn.Module, + model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor,