diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/quantize.py b/egs/librispeech/ASR/pruned_transducer_stateless2/quantize.py
new file mode 100644
index 000000000..9380f74ad
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/quantize.py
@@ -0,0 +1,133 @@
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import re
+
+import torch
+from scaling import ScaledLinear
+
+
+def _get_weight(self: torch.nn.Linear):
+    return self.weight
+
+
+def _get_bias(self: torch.nn.Linear):
+    return self.bias
+
+
+def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> torch.nn.Linear:
+    """Convert a ScaledLinear layer to a Linear layer.
+
+    A ScaledLinear layer is used for training. However, during inference
+    we only need a Linear layer.
+
+    You will need this function when you want to do quantization, since
+    ScaledLinear cannot be quantized by PyTorch.
+
+    Args:
+      scaled_linear:
+        An instance of ScaledLinear.
+    Returns:
+      Return an instance of torch.nn.Linear. It satisfies
+
+        scaled_linear(x) == linear(x)
+
+      for any given input tensor x.
+    """
+    if not hasattr(torch.nn.Linear, "get_weight"):
+        torch.nn.Linear.get_weight = _get_weight
+        torch.nn.Linear.get_bias = _get_bias
+
+    assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)
+    weight = scaled_linear.get_weight()
+    bias = scaled_linear.get_bias()
+    has_bias = bias is not None
+
+    linear = torch.nn.Linear(
+        in_features=scaled_linear.in_features,
+        out_features=scaled_linear.out_features,
+        bias=has_bias,
+        device=weight.device,
+    )
+    linear.weight.data.copy_(weight)
+
+    if has_bias:
+        linear.bias.data.copy_(bias)
+
+    return linear
+
+
+def convert_scaled_linear(model: torch.nn.Module, inplace: bool = False):
+    """Convert **all** ScaledLinear layers in a model to Linear layers.
+
+    Args:
+      model:
+        The input model to be converted.
+      inplace:
+        If True, the input model is modified **inplace**.
+        If False, the input model is copied and we modify the copy.
+    Returns:
+      Return the converted model.
+    """
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    d = {}
+    excluded_patterns = r"self_attn\.(in|out)_proj"
+    p = re.compile(excluded_patterns)
+    for name, m in model.named_modules():
+        if isinstance(m, ScaledLinear):
+            if p.search(name) is not None:
+                continue
+            d[name] = scaled_linear_to_linear(m)
+
+    for k, v in d.items():
+        if "." in k:
+            parent, child = k.rsplit(".", maxsplit=1)
+            setattr(model.get_submodule(parent), child, v)
+        else:
+            setattr(model, k, v)
+
+    return model
+
+
+def dynamic_quantize(
+    model: torch.nn.Module,
+    inplace: bool = False,
+) -> torch.nn.Module:
+    """Apply post-training dynamic quantization to a given model.
+
+    It is also known as post-training weight-only quantization.
+    Weights are quantized to tensors of dtype torch.qint8.
+
+    Only nn.Linear layers are quantized at present.
+
+    Args:
+      model:
+        The model to be quantized.
+      inplace:
+        If True, the passed model is modified inplace.
+        If False, the passed model is copied and we modify the copied model.
+    """
+    converted_model = convert_scaled_linear(model)
+    q_model = torch.quantization.quantize_dynamic(
+        model=converted_model,
+        qconfig_spec={torch.nn.Linear},
+        dtype=torch.qint8,
+        inplace=inplace,
+    )
+    return q_model
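For reference, a minimal usage sketch of the helpers above (the layer sizes
and input shape are illustrative, not taken from the recipe):

    import torch
    from scaling import ScaledLinear
    from quantize import scaled_linear_to_linear

    # Build a ScaledLinear as used during training, then convert it
    # to a plain nn.Linear for inference/quantization.
    scaled = ScaledLinear(in_features=80, out_features=512)
    linear = scaled_linear_to_linear(scaled)

    # The two layers compute the same function.
    x = torch.rand(4, 80)
    assert torch.allclose(scaled(x), linear(x))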
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/test_quantize.py b/egs/librispeech/ASR/pruned_transducer_stateless2/test_quantize.py
new file mode 100755
index 000000000..1e5b5648b
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/test_quantize.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless2/test_quantize.py
+"""
+
+import copy
+import os
+
+import torch
+from quantize import (
+    convert_scaled_linear,
+    dynamic_quantize,
+    scaled_linear_to_linear,
+)
+from scaling import ScaledLinear
+from train import get_params, get_transducer_model
+
+
+def get_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
+
+    params.dynamic_chunk_training = False
+    params.short_chunk_size = 25
+    params.num_left_chunks = 4
+    params.causal_convolution = False
+
+    model = get_transducer_model(params)
+    return model
+
+
+def test_scaled_linear_to_linear():
+    N = 5
+    in_features = 10
+    out_features = 20
+    for bias in [True, False]:
+        scaled_linear = ScaledLinear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+        )
+        linear = scaled_linear_to_linear(scaled_linear)
+        x = torch.rand(N, in_features)
+
+        y1 = scaled_linear(x)
+        y2 = linear(x)
+        assert torch.allclose(y1, y2)
+
+        jit_scaled_linear = torch.jit.script(scaled_linear)
+        jit_linear = torch.jit.script(linear)
+
+        y3 = jit_scaled_linear(x)
+        y4 = jit_linear(x)
+
+        assert torch.allclose(y3, y4)
+        assert torch.allclose(y1, y4)
+
+
+def test_convert_scaled_linear():
+    for inplace in [False, True]:
+        model = get_model()
+        model.eval()
+
+        orig_model = copy.deepcopy(model)
+
+        converted_model = convert_scaled_linear(model, inplace=inplace)
+
+        model = orig_model
+
+        # test encoder
+        N = 2
+        T = 100
+        vocab_size = model.decoder.vocab_size
+
+        x = torch.randn(N, T, 80, dtype=torch.float32)
+        x_lens = torch.full((N,), x.size(1))
+
+        e1, e1_lens = model.encoder(x, x_lens)
+        e2, e2_lens = converted_model.encoder(x, x_lens)
+
+        assert torch.all(torch.eq(e1_lens, e2_lens))
+        assert torch.allclose(e1, e2), (e1 - e2).abs().max()
+
+        # test decoder
+        U = 50
+        y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
+
+        d1 = model.decoder(y)
+        d2 = converted_model.decoder(y)
+
+        assert torch.allclose(d1, d2)
+
+        # test simple projection
+        lm1 = model.simple_lm_proj(d1)
+        am1 = model.simple_am_proj(e1)
+
+        lm2 = converted_model.simple_lm_proj(d2)
+        am2 = converted_model.simple_am_proj(e2)
+
+        assert torch.allclose(lm1, lm2)
+        assert torch.allclose(am1, am2)
+
+        # test joiner
+        e = torch.rand(2, 3, 4, 512)
+        d = torch.rand(2, 3, 4, 512)
+
+        j1 = model.joiner(e, d)
+        j2 = converted_model.joiner(e, d)
+        assert torch.allclose(j1, j2)
+
+
+def test_dynamic_quantize_size_comparison():
+    model = get_model()
+    qmodel = dynamic_quantize(model)
+
+    filename = "icefall-tmp-f32.pt"
+    qfilename = "icefall-tmp-qint8.pt"
+    torch.save(model, filename)
+    torch.save(qmodel, qfilename)
+
+    float_size = os.path.getsize(filename)
+    int8_size = os.path.getsize(qfilename)
+    print("float_size:", float_size)
+    print("int8_size:", int8_size)
+    print(f"ratio: {float_size}/{int8_size}: {float_size/int8_size:.3f}")
+
+    os.remove(filename)
+    os.remove(qfilename)
+
+
+@torch.no_grad()
+def main():
+    test_scaled_linear_to_linear()
+    test_convert_scaled_linear()
+    test_dynamic_quantize_size_comparison()
+
+
+if __name__ == "__main__":
+    torch.manual_seed(20220725)
+    main()
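To double-check which layers were actually replaced by dynamic quantization,
one can inspect the module types afterwards. A sketch, assuming the PyTorch
versions targeted here, where torch.nn.quantized.dynamic.Linear is the
dynamically quantized Linear type; the printed names depend on the model:

    import torch
    from quantize import dynamic_quantize
    from test_quantize import get_model

    qmodel = dynamic_quantize(get_model())
    for name, m in qmodel.named_modules():
        # Layers excluded from conversion (e.g. self_attn in/out
        # projections) will not show up here.
        if isinstance(m, torch.nn.quantized.dynamic.Linear):
            print(name)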
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index 53ea306ff..2fb256ce0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -50,6 +50,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
+from quantize import dynamic_quantize
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -117,6 +118,18 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--quantize",
+        type=str2bool,
+        default=False,
+        help="""True to quantize the model before applying jit.
+        Used only when --jit is True.
+        It uses post-training dynamic quantization. Only
+        ScaledLinear and Linear layers are quantized. Their weights
+        are quantized to torch.qint8 tensors.
+        """,
+    )
+
     parser.add_argument(
         "--context-size",
         type=int,
@@ -185,7 +198,9 @@ def main():
         )
         logging.info(f"averaging {filenames}")
         model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
     elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
@@ -209,9 +224,19 @@ def main():
         # Otherwise, one of its arguments is a ragged tensor and is not
         # torch scriptabe.
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+        if params.quantize:
+            logging.info("Quantization enabled")
+            model = dynamic_quantize(model)
+            filename = "quantized_cpu_jit.pt"
+        else:
+            logging.info("Quantization disabled")
+            filename = "cpu_jit.pt"
+
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
-        filename = params.exp_dir / "cpu_jit.pt"
+
+        filename = params.exp_dir / filename
+
         model.save(str(filename))
         logging.info(f"Saved to {filename}")
     else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/quantize.py b/egs/librispeech/ASR/pruned_transducer_stateless3/quantize.py
new file mode 120000
index 000000000..5ff18ada0
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/quantize.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/quantize.py
\ No newline at end of file
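With these changes, a dynamically quantized TorchScript model can be exported
roughly as follows. The exp-dir, bpe-model, epoch, and avg values below are
placeholders, and --quantize only takes effect together with --jit 1:

    cd egs/librispeech/ASR
    ./pruned_transducer_stateless3/export.py \
        --exp-dir ./pruned_transducer_stateless3/exp \
        --bpe-model data/lang_bpe_500/bpe.model \
        --epoch 25 \
        --avg 5 \
        --jit 1 \
        --quantize 1

    # The quantized model is saved as quantized_cpu_jit.pt and can be
    # loaded back in Python with:
    #   model = torch.jit.load(
    #       "pruned_transducer_stateless3/exp/quantized_cpu_jit.pt"
    #   )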