diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cc61b3b32..accf0aa84 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -479,7 +479,7 @@ def load_checkpoint_if_available( def save_checkpoint( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, @@ -529,7 +529,7 @@ def save_checkpoint( def compute_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, @@ -553,7 +553,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -609,7 +613,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -643,7 +647,7 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, @@ -865,7 +869,6 @@ def run(rank, world_size, args): if world_size > 1: logging.info("Using DDP") model = DDP(model, device_ids=[rank]) - model.device = device optimizer = Eve(model.parameters(), lr=params.initial_lr) @@ -986,7 +989,7 @@ def run(rank, world_size, args): def scan_pessimistic_batches_for_oom( - model: nn.Module, + model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor,