mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)
Support exporting and merging the weights of a LoRA-finetuned Zipformer
parent c0924f0a2f
commit af04e7d7be
@@ -31,8 +31,8 @@ dataset, you should change the argument values according to your dataset.
 
 - For non-streaming model:
 
-./zipformer/export.py \
-  --exp-dir ./zipformer/exp \
+./zipformer_lora/export.py \
+  --exp-dir ./zipformer_lora/exp \
   --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9 \
@@ -48,8 +48,8 @@ for how to use the exported models outside of icefall.
 
 - For streaming model:
 
-./zipformer/export.py \
-  --exp-dir ./zipformer/exp \
+./zipformer_lora/export.py \
+  --exp-dir ./zipformer_lora/exp \
   --causal 1 \
   --chunk-size 16 \
   --left-context-frames 128 \
@@ -70,16 +70,16 @@ for how to use the exported models outside of icefall.
 
 - For non-streaming model:
 
-./zipformer/export.py \
-  --exp-dir ./zipformer/exp \
+./zipformer_lora/export.py \
+  --exp-dir ./zipformer_lora/exp \
   --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9
 
 - For streaming model:
 
-./zipformer/export.py \
-  --exp-dir ./zipformer/exp \
+./zipformer_lora/export.py \
+  --exp-dir ./zipformer_lora/exp \
   --causal 1 \
   --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
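Note: a minimal sketch of loading the exported pretrained.pt back into a model, assuming the file holds {"model": state_dict} (icefall's usual non-torchscript export format); since this commit merges the LoRA weights before saving, the tensors are assumed to fit a model built without LoRA.

import torch

from finetune import get_model, get_params  # the same helpers export.py uses

params = get_params()
# ... update `params` with the same architecture options used during
#     fine-tuning, and set params.use_lora = False ...
model = get_model(params)

# Assumption: pretrained.pt stores {"model": state_dict}.
checkpoint = torch.load("./zipformer_lora/exp/pretrained.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=True)
model.eval()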
@@ -90,15 +90,15 @@ load it by `icefall.checkpoint.load_checkpoint()`.
 
 - For non-streaming model:
 
-To use the generated file with `zipformer/decode.py`,
+To use the generated file with `zipformer_lora/decode.py`,
 you can do:
 
 cd /path/to/exp_dir
 ln -s pretrained.pt epoch-9999.pt
 
 cd /path/to/egs/librispeech/ASR
-./zipformer/decode.py \
-  --exp-dir ./zipformer/exp \
+./zipformer_lora/decode.py \
+  --exp-dir ./zipformer_lora/exp \
   --epoch 9999 \
   --avg 1 \
   --max-duration 600 \
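Note: the epoch-9999.pt symlink makes the exported pretrained.pt look like an ordinary training checkpoint, so --epoch 9999 --avg 1 points the decode script at that single file instead of averaging several real epochs. A rough illustration of the path the flags resolve to (simplified; the actual checkpoint-selection logic lives in the decode scripts):

from pathlib import Path

exp_dir = Path("./zipformer_lora/exp")
epoch, avg = 9999, 1
# With avg == 1 only one checkpoint is read; the symlink maps it to pretrained.pt.
checkpoint_path = exp_dir / f"epoch-{epoch}.pt"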
@@ -107,7 +107,7 @@ you can do:
 
 - For streaming model:
 
-To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
+To use the generated file with `zipformer_lora/decode.py` and `zipformer_lora/streaming_decode.py`, you can do:
 
 cd /path/to/exp_dir
 ln -s pretrained.pt epoch-9999.pt
@@ -115,8 +115,8 @@ To use the generated file with `zipformer/decode.py` and `zipformer/streaming_de
 cd /path/to/egs/librispeech/ASR
 
 # simulated streaming decoding
-./zipformer/decode.py \
-  --exp-dir ./zipformer/exp \
+./zipformer_lora/decode.py \
+  --exp-dir ./zipformer_lora/exp \
   --epoch 9999 \
   --avg 1 \
   --max-duration 600 \
@@ -127,8 +127,8 @@ To use the generated file with `zipformer/decode.py` and `zipformer/streaming_de
   --bpe-model data/lang_bpe_500/bpe.model
 
 # chunk-wise streaming decoding
-./zipformer/streaming_decode.py \
-  --exp-dir ./zipformer/exp \
+./zipformer_lora/streaming_decode.py \
+  --exp-dir ./zipformer_lora/exp \
   --epoch 9999 \
   --avg 1 \
   --max-duration 600 \
@@ -167,7 +167,7 @@ import k2
 import torch
 from scaling_converter import convert_scaled_to_non_scaled
 from torch import Tensor, nn
-from train import add_model_arguments, get_model, get_params
+from finetune import add_model_arguments, add_finetune_arguments, get_model, get_params
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -225,7 +225,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="zipformer/exp",
+        default="zipformer_lora/exp",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
@@ -256,6 +256,7 @@ def get_parser():
     )
 
     add_model_arguments(parser)
+    add_finetune_arguments(parser)
 
     return parser
 
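add_finetune_arguments comes from finetune.py and registers the LoRA fine-tuning options on the parser. A rough sketch of what such a helper can look like; only the use_lora option is implied by this commit (the merge code below sets params.use_lora = False), the second flag is an illustrative assumption rather than the real option set:

from icefall.utils import str2bool


def add_finetune_arguments(parser):
    parser.add_argument(
        "--use-lora",
        type=str2bool,
        default=True,
        help="Whether to insert LoRA adapters into the model.",
    )
    # Illustrative only; the actual options are defined in finetune.py.
    parser.add_argument(
        "--lora-r",
        type=int,
        default=8,
        help="Rank of the LoRA update matrices.",
    )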
@@ -486,6 +487,22 @@ def main():
                 )
             )
 
+    # merge the LoRA weights
     model.eval()
 
+    params.use_lora = False
+    base_model = get_model(params)
+
+    new_state_dict = {}
+    state_dict = model.state_dict()
+    param_names = base_model.state_dict().keys()
+    for k in param_names:
+        assert k in state_dict.keys()
+        new_state_dict[k] = state_dict[k]
+
+    base_model.load_state_dict(new_state_dict, strict=True)
+
+    model = base_model
+    model.eval()
+
     if params.jit is True:
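The copy above yields merged weights because model.eval() (see the "# merge the LoRA weights" comment) is expected to fold each layer's low-rank update into its frozen weight; the parameters that share names with a plain use_lora=False model then already contain W + scaling * B A, so loading them into base_model gives an ordinary Zipformer checkpoint. A minimal sketch of that merge-on-eval idea, with illustrative class and attribute names rather than the actual zipformer_lora ones:

import torch
import torch.nn as nn


class LoRALinear(nn.Linear):
    """nn.Linear with a rank-r LoRA update that is folded into `weight` in eval mode."""

    def __init__(self, in_features, out_features, r=8, alpha=16, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.scaling = alpha / r
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.normal_(self.lora_A, std=0.01)  # B stays zero, so the initial update is zero
        self.weight.requires_grad_(False)       # only the LoRA factors are trained
        self.merged = False

    def train(self, mode: bool = True):
        super().train(mode)
        delta = (self.lora_B @ self.lora_A) * self.scaling
        if not mode and not self.merged:   # eval(): fold the update into the base weight
            self.weight.data += delta
            self.merged = True
        elif mode and self.merged:         # back to train(): undo the fold
            self.weight.data -= delta
            self.merged = False
        return self

    def forward(self, x):
        out = super().forward(x)
        if not self.merged:                # training: apply the low-rank update on the fly
            out = out + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return out

With layers like this, a single model.eval() is enough for state_dict() to expose the merged tensors under the base parameter names, which is exactly what the loop over param_names copies into base_model before export.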