Support quantization

Fangjun Kuang 2022-07-25 18:14:21 +08:00
parent d99796898c
commit 2777c0b0b3
4 changed files with 323 additions and 2 deletions

View File

@@ -0,0 +1,133 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import re

import torch
from scaling import ScaledLinear


def _get_weight(self: torch.nn.Linear):
    return self.weight


def _get_bias(self: torch.nn.Linear):
    return self.bias


def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> torch.nn.Linear:
    """Convert a ScaledLinear layer to a Linear layer.

    The ScaledLinear layer is used for training. However, during inference
    we only need a Linear layer.

    You will need this function when you want to do quantization since
    ScaledLinear cannot be quantized by PyTorch.

    Args:
      scaled_linear:
        An instance of ScaledLinear.
    Returns:
      Return an instance of torch.nn.Linear. It satisfies

        scaled_linear(x) == linear(x)

      for any given input tensor x.
    """
    if not hasattr(torch.nn.Linear, "get_weight"):
        torch.nn.Linear.get_weight = _get_weight
        torch.nn.Linear.get_bias = _get_bias

    assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)

    weight = scaled_linear.get_weight()
    bias = scaled_linear.get_bias()
    has_bias = bias is not None

    linear = torch.nn.Linear(
        in_features=scaled_linear.in_features,
        out_features=scaled_linear.out_features,
        bias=has_bias,
        device=weight.device,
    )
    linear.weight.data.copy_(weight)

    if has_bias:
        linear.bias.data.copy_(bias)

    return linear
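# Editorial note (not part of the original commit): ScaledLinear keeps a
# learnable log-scale next to the raw weight, and get_weight()/get_bias()
# return the scale already folded in (roughly weight * weight_scale.exp()).
# That is why copying them into a plain nn.Linear reproduces the ScaledLinear
# outputs exactly, as verified in test_quantize.py below.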


def convert_scaled_linear(model: torch.nn.Module, inplace: bool = False):
    """Convert **all** ScaledLinear layers in a model to Linear layers.

    Args:
      model:
        The input model to be converted.
      inplace:
        If True, the input model is modified **inplace**.
        If False, the input model is copied and we modify the copy.
    Returns:
      Return the converted model.
    """
    if not inplace:
        model = copy.deepcopy(model)

    d = {}
    excluded_patterns = r"self_attn\.(in|out)_proj"
    p = re.compile(excluded_patterns)
    for name, m in model.named_modules():
        if isinstance(m, ScaledLinear):
            if p.search(name) is not None:
                continue
            d[name] = scaled_linear_to_linear(m)

    for k, v in d.items():
        if "." in k:
            parent, child = k.rsplit(".", maxsplit=1)
            setattr(model.get_submodule(parent), child, v)
        else:
            setattr(model, k, v)

    return model


def dynamic_quantize(
    model: torch.nn.Module,
    inplace: bool = False,
) -> torch.nn.Module:
    """Apply post-training dynamic quantization to a given model.

    It is also known as post-training weight-only quantization.
    Weights are quantized to tensors of dtype torch.qint8.

    Only nn.Linear layers are quantized at present.

    Args:
      model:
        The model to be quantized.
      inplace:
        If True, the passed model is modified inplace.
        If False, the passed model is copied and we modify the copied model.
    Returns:
      Return the quantized model.
    """
    converted_model = convert_scaled_linear(model)
    q_model = torch.quantization.quantize_dynamic(
        model=converted_model,
        qconfig_spec={torch.nn.Linear},
        dtype=torch.qint8,
        inplace=inplace,
    )
    return q_model
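Taken together, the helpers above form a small pipeline: every ScaledLinear is first replaced by an equivalent nn.Linear, and PyTorch's dynamic quantization then swaps those Linear layers for int8 dynamic ones. A minimal usage sketch, not part of the commit itself (it assumes a params object populated as in train.py / test_quantize.py below):

    import torch
    from quantize import dynamic_quantize
    from train import get_params, get_transducer_model

    params = get_params()  # assumed: remaining fields filled in as in test_quantize.py
    model = get_transducer_model(params)
    model.eval()
    int8_model = dynamic_quantize(model)  # ScaledLinear -> Linear -> int8 dynamic Linear
    torch.jit.script(int8_model).save("quantized_cpu_jit.pt")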

View File

@@ -0,0 +1,162 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless2/test_quantize.py
"""
import copy
import os
import torch
from quantize import (
convert_scaled_linear,
dynamic_quantize,
scaled_linear_to_linear,
)
from scaling import ScaledLinear
from train import get_params, get_transducer_model
def get_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.dynamic_chunk_training = False
params.short_chunk_size = 25
params.num_left_chunks = 4
params.causal_convolution = False
model = get_transducer_model(params)
return model


def test_scaled_linear_to_linear():
    N = 5
    in_features = 10
    out_features = 20
    for bias in [True, False]:
        scaled_linear = ScaledLinear(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
        )
        linear = scaled_linear_to_linear(scaled_linear)
        x = torch.rand(N, in_features)
        y1 = scaled_linear(x)
        y2 = linear(x)
        assert torch.allclose(y1, y2)

        jit_scaled_linear = torch.jit.script(scaled_linear)
        jit_linear = torch.jit.script(linear)

        y3 = jit_scaled_linear(x)
        y4 = jit_linear(x)

        assert torch.allclose(y3, y4)
        assert torch.allclose(y1, y4)


def test_convert_scaled_linear():
    for inplace in [False, True]:
        model = get_model()
        model.eval()

        orig_model = copy.deepcopy(model)

        converted_model = convert_scaled_linear(model, inplace=inplace)

        model = orig_model

        # test encoder
        N = 2
        T = 100
        vocab_size = model.decoder.vocab_size

        x = torch.randn(N, T, 80, dtype=torch.float32)
        x_lens = torch.full((N,), x.size(1))

        e1, e1_lens = model.encoder(x, x_lens)
        e2, e2_lens = converted_model.encoder(x, x_lens)

        assert torch.all(torch.eq(e1_lens, e2_lens))
        assert torch.allclose(e1, e2), (e1 - e2).abs().max()

        # test decoder
        U = 50
        y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))

        d1 = model.decoder(y)
        d2 = converted_model.decoder(y)

        assert torch.allclose(d1, d2)

        # test simple projection
        lm1 = model.simple_lm_proj(d1)
        am1 = model.simple_am_proj(e1)

        lm2 = converted_model.simple_lm_proj(d2)
        am2 = converted_model.simple_am_proj(e2)

        assert torch.allclose(lm1, lm2)
        assert torch.allclose(am1, am2)

        # test joiner
        e = torch.rand(2, 3, 4, 512)
        d = torch.rand(2, 3, 4, 512)

        j1 = model.joiner(e, d)
        j2 = converted_model.joiner(e, d)
        assert torch.allclose(j1, j2)


def test_dynamic_quantize_size_comparison():
    model = get_model()
    qmodel = dynamic_quantize(model)

    filename = "icefall-tmp-f32.pt"
    qfilename = "icefall-tmp-qin8.pt"

    torch.save(model, filename)
    torch.save(qmodel, qfilename)

    float_size = os.path.getsize(filename)
    int8_size = os.path.getsize(qfilename)

    print("float_size:", float_size)
    print("int8_size:", int8_size)
    print(f"ratio: {float_size}/{int8_size}: {float_size/int8_size:.3f}")

    os.remove(filename)
    os.remove(qfilename)


@torch.no_grad()
def main():
    test_scaled_linear_to_linear()
    test_convert_scaled_linear()
    test_dynamic_quantize_size_comparison()


if __name__ == "__main__":
    torch.manual_seed(20220725)
    main()
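The size gap that test_dynamic_quantize_size_comparison prints comes from the weight storage: after torch.quantization.quantize_dynamic, each selected nn.Linear is replaced by a dynamically quantized Linear whose weights are stored as torch.qint8 together with their quantization parameters, while activations stay in float at runtime. A small standalone sketch (independent of the model above, not part of the commit) that makes this visible:

    import torch

    m = torch.nn.Sequential(torch.nn.Linear(10, 20))
    qm = torch.quantization.quantize_dynamic(m, {torch.nn.Linear}, dtype=torch.qint8)
    print(type(qm[0]))           # a dynamically quantized Linear module
    print(qm[0].weight().dtype)  # torch.qint8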

View File

@@ -50,6 +50,7 @@ from pathlib import Path
import sentencepiece as spm
import torch
from quantize import dynamic_quantize
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@@ -117,6 +118,18 @@ def get_parser():
        """,
    )

    parser.add_argument(
        "--quantize",
        type=str2bool,
        default=False,
        help="""True to quantize the model before applying torch.jit.script.
        Used only when --jit is True.
        It uses post-training dynamic quantization. Only ScaledLinear and
        Linear layers are quantized. Their weights are quantized to
        torch.qint8 tensors.
        """,
    )

    parser.add_argument(
        "--context-size",
        type=int,
@@ -185,7 +198,9 @@ def main():
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
@@ -209,9 +224,19 @@ def main():
        # Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptable.
        model.__class__.forward = torch.jit.ignore(model.__class__.forward)

        if params.quantize:
            logging.info("Quantization enabled")
            model = dynamic_quantize(model)
            filename = "quantized_cpu_jit.pt"
        else:
            logging.info("Quantization disabled")
            filename = "cpu_jit.pt"

        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = params.exp_dir / "cpu_jit.pt"
        filename = params.exp_dir / filename
        model.save(str(filename))
        logging.info(f"Saved to {filename}")
    else:
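With --quantize enabled, the scripted model is saved as quantized_cpu_jit.pt under --exp-dir instead of cpu_jit.pt. A hedged sketch of loading it back; the path below is only an example and depends on the --exp-dir that was used:

    import torch

    model = torch.jit.load("pruned_transducer_stateless2/exp/quantized_cpu_jit.pt")
    model.eval()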

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/quantize.py