mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
modify usage of the model device in train.py
This commit is contained in:
parent
a72048be3e
commit
dd439b1906
@ -479,7 +479,7 @@ def load_checkpoint_if_available(
|
|||||||
|
|
||||||
def save_checkpoint(
|
def save_checkpoint(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
model_avg: Optional[nn.Module] = None,
|
model_avg: Optional[nn.Module] = None,
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
scheduler: Optional[LRSchedulerType] = None,
|
scheduler: Optional[LRSchedulerType] = None,
|
||||||
@ -529,7 +529,7 @@ def save_checkpoint(
|
|||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
@ -553,7 +553,11 @@ def compute_loss(
|
|||||||
warmup: a floating point value which increases throughout training;
|
warmup: a floating point value which increases throughout training;
|
||||||
values >= 1.0 are fully warmed up and have all modules present.
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
"""
|
"""
|
||||||
device = model.device
|
device = (
|
||||||
|
model.device
|
||||||
|
if isinstance(model, DDP)
|
||||||
|
else next(model.parameters()).device
|
||||||
|
)
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
assert feature.ndim == 3
|
assert feature.ndim == 3
|
||||||
@ -609,7 +613,7 @@ def compute_loss(
|
|||||||
|
|
||||||
def compute_validation_loss(
|
def compute_validation_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
@ -643,7 +647,7 @@ def compute_validation_loss(
|
|||||||
|
|
||||||
def train_one_epoch(
|
def train_one_epoch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
scheduler: LRSchedulerType,
|
scheduler: LRSchedulerType,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
@ -865,7 +869,6 @@ def run(rank, world_size, args):
|
|||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
logging.info("Using DDP")
|
logging.info("Using DDP")
|
||||||
model = DDP(model, device_ids=[rank])
|
model = DDP(model, device_ids=[rank])
|
||||||
model.device = device
|
|
||||||
|
|
||||||
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
||||||
|
|
||||||
@ -986,7 +989,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
|
|
||||||
def scan_pessimistic_batches_for_oom(
|
def scan_pessimistic_batches_for_oom(
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user