modify usage of the model device in train.py

This commit is contained in:
yaozengwei 2022-05-06 22:03:49 +08:00
parent a72048be3e
commit dd439b1906

View File

@ -479,7 +479,7 @@ def load_checkpoint_if_available(
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
@ -529,7 +529,7 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
@ -553,7 +553,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -609,7 +613,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
@ -643,7 +647,7 @@ def compute_validation_loss(
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
@ -865,7 +869,6 @@ def run(rank, world_size, args):
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Eve(model.parameters(), lr=params.initial_lr)
@ -986,7 +989,7 @@ def run(rank, world_size, args):
def scan_pessimistic_batches_for_oom(
model: nn.Module,
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,