From 952abee88c59a8bcaa07ef20f3dedb2aabca3885 Mon Sep 17 00:00:00 2001 From: zzasdf <15218404468@163.com> Date: Tue, 19 Mar 2024 17:29:34 +0800 Subject: [PATCH 1/2] add checkpoint convert script --- .../local/convert_checkpoint_from_fairseq.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py diff --git a/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py new file mode 100644 index 000000000..f06adb9e5 --- /dev/null +++ b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py @@ -0,0 +1,17 @@ +# simple script to convert a fairseq checkpoint into pytorch parameter state dict +import torch +from collections import OrderedDict +from argparse import ArgumentParser + +parser = ArgumentParser() +parser.add_argument("--src") +parser.add_argument("--tgt") + +args = parser.parse_args() +src = args.src +tgt = args.tgt + +old_checkpoint = torch.load(src) +new_checkpoint = OrderedDict() +new_checkpoint['model'] = old_checkpoint['model'] +torch.save(new_checkpoint, tgt) From ac73f60f5f8db30edaa83b31d45e7e2cdd68eae4 Mon Sep 17 00:00:00 2001 From: zzasdf <15218404468@163.com> Date: Tue, 19 Mar 2024 17:34:06 +0800 Subject: [PATCH 2/2] format --- .../SSL/local/convert_checkpoint_from_fairseq.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py index f06adb9e5..4212cd9c6 100644 --- a/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py +++ b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py @@ -1,7 +1,8 @@ # simple script to convert a fairseq checkpoint into pytorch parameter state dict -import torch -from collections import OrderedDict from argparse import ArgumentParser +from collections import OrderedDict + +import torch parser = ArgumentParser() parser.add_argument("--src") @@ -13,5 +14,5 @@ tgt = args.tgt old_checkpoint = torch.load(src) new_checkpoint = OrderedDict() -new_checkpoint['model'] = old_checkpoint['model'] +new_checkpoint["model"] = old_checkpoint["model"] torch.save(new_checkpoint, tgt)