mirror of https://github.com/k2-fsa/icefall.git
Support quantization

commit 2777c0b0b3
parent d99796898c
egs/librispeech/ASR/pruned_transducer_stateless2/quantize.py (new file, 133 lines)
@@ -0,0 +1,133 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import re

import torch
from scaling import ScaledLinear


def _get_weight(self: torch.nn.Linear):
    return self.weight


def _get_bias(self: torch.nn.Linear):
    return self.bias


def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> torch.nn.Linear:
    """Convert a ScaledLinear layer to a Linear layer.

    A ScaledLinear layer is used during training. During inference,
    we only need a Linear layer.

    You will need this function when you want to do quantization, since
    ScaledLinear cannot be quantized by PyTorch.

    Args:
      scaled_linear:
        An instance of ScaledLinear.
    Returns:
      Return an instance of torch.nn.Linear. It satisfies

        scaled_linear(x) == linear(x)

      for any given input tensor x.
    """
    if not hasattr(torch.nn.Linear, "get_weight"):
        torch.nn.Linear.get_weight = _get_weight
        torch.nn.Linear.get_bias = _get_bias

    assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)
    weight = scaled_linear.get_weight()
    bias = scaled_linear.get_bias()
    has_bias = bias is not None

    linear = torch.nn.Linear(
        in_features=scaled_linear.in_features,
        out_features=scaled_linear.out_features,
        bias=has_bias,
        device=weight.device,
    )
    linear.weight.data.copy_(weight)

    if has_bias:
        linear.bias.data.copy_(bias)

    return linear


def convert_scaled_linear(model: torch.nn.Module, inplace: bool = False):
    """Convert **all** ScaledLinear layers in a model to Linear layers.

    Args:
      model:
        The input model to be converted.
      inplace:
        If True, the input model is modified **inplace**.
        If False, the input model is copied and we modify the copy.
    Returns:
      Return the converted model.
    """
    if not inplace:
        model = copy.deepcopy(model)

    d = {}
    excluded_patterns = r"self_attn\.(in|out)_proj"
    p = re.compile(excluded_patterns)
    for name, m in model.named_modules():
        if isinstance(m, ScaledLinear):
            if p.search(name) is not None:
                continue
            d[name] = scaled_linear_to_linear(m)

    for k, v in d.items():
        if "." in k:
            parent, child = k.rsplit(".", maxsplit=1)
            setattr(model.get_submodule(parent), child, v)
        else:
            setattr(model, k, v)

    return model


def dynamic_quantize(
    model: torch.nn.Module,
    inplace: bool = False,
) -> torch.nn.Module:
    """Apply post-training dynamic quantization to a given model.

    It is also known as post-training weight-only quantization.
    Weights are quantized to tensors of dtype torch.qint8.

    Only nn.Linear layers are quantized at present.

    Args:
      model:
        The model to be quantized.
      inplace:
        If True, the passed model is modified inplace.
        If False, the passed model is copied and we modify the copied model.
    """
    converted_model = convert_scaled_linear(model)
    q_model = torch.quantization.quantize_dynamic(
        model=converted_model,
        qconfig_spec={torch.nn.Linear},
        dtype=torch.qint8,
        inplace=inplace,
    )
    return q_model
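As a standalone illustration of the quantization step above, here is a minimal sketch that uses only plain torch.nn.Linear layers; the toy model and tensor shapes are made up for the example and are not part of this commit:

import torch

# A toy two-layer model standing in for the converted transducer model.
toy = torch.nn.Sequential(
    torch.nn.Linear(80, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 500),
)
toy.eval()

# Post-training dynamic quantization: weights of all nn.Linear layers are
# stored as torch.qint8; activations remain float and are quantized on the
# fly at inference time.
q_toy = torch.quantization.quantize_dynamic(
    model=toy,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)

x = torch.rand(4, 80)
with torch.no_grad():
    y_float = toy(x)
    y_int8 = q_toy(x)

# The outputs agree only approximately, since int8 weights lose precision.
print("max abs diff:", (y_float - y_int8).abs().max().item())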
egs/librispeech/ASR/pruned_transducer_stateless2/test_quantize.py (new executable file, 162 lines)

@@ -0,0 +1,162 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
To run this file, do:

    cd icefall/egs/librispeech/ASR
    python ./pruned_transducer_stateless2/test_quantize.py
"""

import copy
import os

import torch
from quantize import (
    convert_scaled_linear,
    dynamic_quantize,
    scaled_linear_to_linear,
)
from scaling import ScaledLinear
from train import get_params, get_transducer_model


def get_model():
    params = get_params()
    params.vocab_size = 500
    params.blank_id = 0
    params.context_size = 2
    params.unk_id = 2

    params.dynamic_chunk_training = False
    params.short_chunk_size = 25
    params.num_left_chunks = 4
    params.causal_convolution = False

    model = get_transducer_model(params)
    return model


def test_scaled_linear_to_linear():
    N = 5
    in_features = 10
    out_features = 20
    for bias in [True, False]:
        scaled_linear = ScaledLinear(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
        )
        linear = scaled_linear_to_linear(scaled_linear)
        x = torch.rand(N, in_features)

        y1 = scaled_linear(x)
        y2 = linear(x)
        assert torch.allclose(y1, y2)

        jit_scaled_linear = torch.jit.script(scaled_linear)
        jit_linear = torch.jit.script(linear)

        y3 = jit_scaled_linear(x)
        y4 = jit_linear(x)

        assert torch.allclose(y3, y4)
        assert torch.allclose(y1, y4)


def test_convert_scaled_linear():
    for inplace in [False, True]:
        model = get_model()
        model.eval()

        orig_model = copy.deepcopy(model)

        converted_model = convert_scaled_linear(model, inplace=inplace)

        model = orig_model

        # test encoder
        N = 2
        T = 100
        vocab_size = model.decoder.vocab_size

        x = torch.randn(N, T, 80, dtype=torch.float32)
        x_lens = torch.full((N,), x.size(1))

        e1, e1_lens = model.encoder(x, x_lens)
        e2, e2_lens = converted_model.encoder(x, x_lens)

        assert torch.all(torch.eq(e1_lens, e2_lens))
        assert torch.allclose(e1, e2), (e1 - e2).abs().max()

        # test decoder
        U = 50
        y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))

        d1 = model.decoder(y)
        d2 = converted_model.decoder(y)

        assert torch.allclose(d1, d2)

        # test simple projection
        lm1 = model.simple_lm_proj(d1)
        am1 = model.simple_am_proj(e1)

        lm2 = converted_model.simple_lm_proj(d2)
        am2 = converted_model.simple_am_proj(e2)

        assert torch.allclose(lm1, lm2)
        assert torch.allclose(am1, am2)

        # test joiner
        e = torch.rand(2, 3, 4, 512)
        d = torch.rand(2, 3, 4, 512)

        j1 = model.joiner(e, d)
        j2 = converted_model.joiner(e, d)
        assert torch.allclose(j1, j2)


def test_dynamic_quantize_size_comparison():
    model = get_model()
    qmodel = dynamic_quantize(model)

    filename = "icefall-tmp-f32.pt"
    qfilename = "icefall-tmp-qin8.pt"
    torch.save(model, filename)
    torch.save(qmodel, qfilename)

    float_size = os.path.getsize(filename)
    int8_size = os.path.getsize(qfilename)
    print("float_size:", float_size)
    print("int8_size:", int8_size)
    print(f"ratio: {float_size}/{int8_size}: {float_size/int8_size:.3f}")

    os.remove(filename)
    os.remove(qfilename)


@torch.no_grad()
def main():
    test_scaled_linear_to_linear()
    test_convert_scaled_linear()
    test_dynamic_quantize_size_comparison()


if __name__ == "__main__":
    torch.manual_seed(20220725)
    main()
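If you want to double-check which layers quantize_dynamic actually replaced, one option (a sketch, not part of this commit) is to scan the module tree for the dynamic quantized Linear class:

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
qmodel = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8
)

for name, m in qmodel.named_modules():
    # quantize_dynamic swaps each nn.Linear for its dynamic int8 counterpart.
    if isinstance(m, torch.nn.quantized.dynamic.Linear):
        print(name, "->", type(m).__name__)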
egs/librispeech/ASR/pruned_transducer_stateless2/export.py (modified)

@@ -50,6 +50,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
+from quantize import dynamic_quantize
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -117,6 +118,18 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--quantize",
+        type=str2bool,
+        default=False,
+        help="""True to quantize the model before applying jit.
+        Used only when --jit is True.
+        It uses post-training dynamic quantization. Only
+        ScaledLinear and Linear layers are quantized. Their weights
+        are quantized to torch.qint8 tensors.
+        """,
+    )
+
     parser.add_argument(
         "--context-size",
         type=int,
@@ -185,7 +198,9 @@ def main():
         )
         logging.info(f"averaging {filenames}")
         model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
     elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
@@ -209,9 +224,19 @@ def main():
         # Otherwise, one of its arguments is a ragged tensor and is not
         # torch scriptable.
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+        if params.quantize:
+            logging.info("Quantization enabled")
+            model = dynamic_quantize(model)
+            filename = "quantized_cpu_jit.pt"
+        else:
+            logging.info("Quantization disabled")
+            filename = "cpu_jit.pt"
+
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = params.exp_dir / "cpu_jit.pt"
+
+        filename = params.exp_dir / filename
+
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
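After exporting with --jit 1 --quantize 1, the saved TorchScript file can be used without any of the training code. A minimal loading sketch; the path below is a placeholder for your --exp-dir, and the encoder call follows the (N, T, 80) fbank convention used in test_quantize.py above:

import torch

# Placeholder path; export.py writes the file into --exp-dir.
model = torch.jit.load("exp/quantized_cpu_jit.pt", map_location="cpu")
model.eval()

# The top-level forward() was excluded via torch.jit.ignore above, so call
# the submodules directly, e.g. the encoder on a batch of fbank features.
x = torch.rand(1, 100, 80)
x_lens = torch.tensor([100])
with torch.no_grad():
    encoder_out, encoder_out_lens = model.encoder(x, x_lens)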
egs/librispeech/ASR/pruned_transducer_stateless3/quantize.py (new symbolic link)

@@ -0,0 +1 @@
../pruned_transducer_stateless2/quantize.py