from local

This commit is contained in:
dohe0342 2023-05-24 20:53:52 +09:00
parent edb017e7cd
commit 520eb2f93f
3 changed files with 3 additions and 0 deletions

View File

@ -142,6 +142,9 @@ class LoRAHook():
lora_out = self.lora(input)
output = input + lora_out
def save_checkpoint(self, i, save_dir):
torch.save(self.lora.state_dict(), f"{save_dir}/lora_{i}.pt")
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):