mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
modify usage of the model device in train.py
This commit is contained in:
parent
a72048be3e
commit
dd439b1906
@ -479,7 +479,7 @@ def load_checkpoint_if_available(
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
@ -529,7 +529,7 @@ def save_checkpoint(
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
is_training: bool,
|
||||
@ -553,7 +553,11 @@ def compute_loss(
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = model.device
|
||||
device = (
|
||||
model.device
|
||||
if isinstance(model, DDP)
|
||||
else next(model.parameters()).device
|
||||
)
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -609,7 +613,7 @@ def compute_loss(
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
sp: spm.SentencePieceProcessor,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
@ -643,7 +647,7 @@ def compute_validation_loss(
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: LRSchedulerType,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
@ -865,7 +869,6 @@ def run(rank, world_size, args):
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank])
|
||||
model.device = device
|
||||
|
||||
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
||||
|
||||
@ -986,7 +989,7 @@ def run(rank, world_size, args):
|
||||
|
||||
|
||||
def scan_pessimistic_batches_for_oom(
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user