modify usage of the model device in train.py

yaozengwei 2022-05-06 22:03:49 +08:00
parent a72048be3e
commit dd439b1906


@@ -479,7 +479,7 @@ def load_checkpoint_if_available(
 def save_checkpoint(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     model_avg: Optional[nn.Module] = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
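The annotation now admits the DistributedDataParallel wrapper, since that is what run() passes around once world_size > 1. For Union[nn.Module, DDP] and the isinstance(model, DDP) check below to resolve, train.py needs Union and the DDP alias in scope; presumably it already imports them along these lines:

    from typing import Optional, Union

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP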
@@ -529,7 +529,7 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
@@ -553,7 +553,11 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
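This hunk is the functional core of the commit. A bare nn.Module has no device attribute; the old model.device lookup only worked because run() patched the attribute onto the DDP wrapper by hand (that patch is removed below). DistributedDataParallel does expose device natively, so the new expression uses it when available and otherwise falls back to the device of the model's first parameter. A minimal standalone sketch of the same pattern (get_model_device is an illustrative name, not part of the diff):

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def get_model_device(model: nn.Module) -> torch.device:
        # DDP records the device it was built for; a plain nn.Module does
        # not, so fall back to the first parameter's device. (Assumes the
        # model has at least one parameter.)
        if isinstance(model, DDP):
            return model.device
        return next(model.parameters()).device

    model = nn.Linear(80, 500)
    print(get_model_device(model))  # cpu (or cuda:0 after model.to("cuda:0"))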
@@ -609,7 +613,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -643,7 +647,7 @@ def compute_validation_loss(
 def train_one_epoch(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
@@ -865,7 +869,6 @@ def run(rank, world_size, args):
     if world_size > 1:
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
-        model.device = device

     optimizer = Eve(model.parameters(), lr=params.initial_lr)
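With the isinstance-based lookup in compute_loss, the manual model.device = device monkey-patch after wrapping is no longer needed: DDP determines and stores its own device at construction time. A runnable single-process demonstration on CPU (gloo backend; the address and port are arbitrary local values):

    import os

    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Single-process process group, just to construct a DDP wrapper on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = DDP(nn.Linear(80, 500))
    print(model.device)  # cpu -- set by DDP itself, no patching required

    dist.destroy_process_group()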
@@ -986,7 +989,7 @@ def run(rank, world_size, args):
 def scan_pessimistic_batches_for_oom(
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,