Merge pull request #2 from zzasdf/k2ssl-util

checkpoint convert script
This commit is contained in:
Yifan Yang 2024-03-19 17:44:51 +08:00 committed by GitHub
commit 482c24eab0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -0,0 +1,18 @@
# simple script to convert a fairseq checkpoint into pytorch parameter state dict
from argparse import ArgumentParser
from collections import OrderedDict
import torch
parser = ArgumentParser()
parser.add_argument("--src")
parser.add_argument("--tgt")
args = parser.parse_args()
src = args.src
tgt = args.tgt
old_checkpoint = torch.load(src)
new_checkpoint = OrderedDict()
new_checkpoint["model"] = old_checkpoint["model"]
torch.save(new_checkpoint, tgt)