Keep model_avg on cpu (#348)

* keep model_avg on cpu

* explicitly convert model_avg to cpu

* minor fix

* remove device convertion for model_avg

* modify usage of the model device in train.py

* change model.device to next(model.parameters()).device for decoding

* assert params.start_epoch>0

* assert params.start_epoch>0, params.start_epoch
This commit is contained in:
Zengwei Yao 2022-05-07 10:42:34 +08:00 committed by GitHub
parent 8e3c89076e
commit c059ef3169
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 21 deletions

View File

@ -276,7 +276,7 @@ def greedy_search(
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device, dtype=torch.int64
@ -350,7 +350,7 @@ def greedy_search_batch(
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device
device = next(model.parameters()).device
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
@ -580,7 +580,7 @@ def modified_beam_search(
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
B[i].add(
@ -705,7 +705,7 @@ def _deprecated_modified_beam_search(
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
T = encoder_out.size(1)
@ -813,7 +813,7 @@ def beam_search(
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size,

View File

@ -250,7 +250,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
@ -560,7 +560,6 @@ def main():
model.to(device)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

View File

@ -125,8 +125,8 @@ def get_parser():
"--start-epoch",
type=int,
default=1,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
@ -479,7 +479,7 @@ def load_checkpoint_if_available(
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
@ -529,7 +529,7 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
@ -553,7 +553,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -609,7 +613,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
@ -643,7 +647,7 @@ def compute_validation_loss(
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
@ -857,6 +861,7 @@ def run(rank, world_size, args):
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
@ -865,11 +870,6 @@ def run(rank, world_size, args):
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
if rank == 0:
model_avg.to(device)
model_avg.device = device
optimizer = Eve(model.parameters(), lr=params.initial_lr)
@ -990,7 +990,7 @@ def run(rank, world_size, args):
def scan_pessimistic_batches_for_oom(
model: nn.Module,
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,

View File

@ -467,5 +467,7 @@ def average_state_dict(
uniqued_names = list(uniqued.values())
for k in uniqued_names:
state_dict_1[k] *= weight_1
state_dict_1[k] += state_dict_2[k] * weight_2
state_dict_1[k] += (
state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
)
state_dict_1[k] *= scaling_factor