From f572e149a94c859f967e11a63645400016e81859 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 30 Jul 2022 21:17:31 +0800 Subject: [PATCH] Decode with exported models. --- ...pruned-transducer-stateless3-2022-05-13.sh | 41 +++ .../pruned_transducer_stateless2/decoder.py | 16 +- .../pruned_transducer_stateless2/joiner.py | 8 +- .../pruned_transducer_stateless3/export.py | 182 +++++++++- .../jit_pretrained.py | 338 ++++++++++++++++++ .../ncnn_pretrained.py | 161 +++++++++ .../onnx_pretrained.py | 337 +++++++++++++++++ .../pretrained.py | 11 +- .../scaling_converter.py | 189 ++++++++++ .../test_scaling_converter.py | 201 +++++++++++ 10 files changed, 1472 insertions(+), 12 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/ncnn_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index b992515bb..00a6e2553 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -42,15 +42,56 @@ log "Export to torchscript model" --avg 1 \ --jit 1 +./pruned_transducer_stateless3/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit-trace 1 + ls -lh $repo/exp/*.onnx ls -lh $repo/exp/*.pt +log "Decode with ONNX models" + ./pruned_transducer_stateless3/onnx_check.py \ --jit-filename $repo/exp/cpu_jit.pt \ --onnx-encoder-filename $repo/exp/encoder.onnx \ --onnx-decoder-filename $repo/exp/decoder.onnx \ --onnx-joiner-filename $repo/exp/joiner.onnx +./pruned_transducer_stateless3/onnx_pretrained.py \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Decode with models exported by torch.jit.trace()" + +./pruned_transducer_stateless3/jit_pretrained.py \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless3/jit_pretrained.py \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_script.pt \ + --decoder-model-filename $repo/exp/decoder_jit_script.pt \ + --joiner-model-filename $repo/exp/joiner_jit_script.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for sym in 1 2 3; do log "Greedy search with --max-sym-per-frame $sym" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 1ddfce034..32252b64f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -77,7 +79,9 @@ class Decoder(nn.Module): # It is to support torch script self.conv = nn.Identity() - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + def forward( + self, y: torch.Tensor, need_pad: Union[bool, torch.Tensor] = True + ) -> torch.Tensor: """ Args: y: @@ -88,18 +92,24 @@ class Decoder(nn.Module): Returns: Return a tensor of shape (N, U, decoder_dim). """ + if isinstance(need_pad, torch.Tensor): + # This if for torch.jit.trace(), which cannot handle the case + # when the input argument is not a tensor. + need_pad = bool(need_pad) + y = y.to(torch.int64) embedding_out = self.embedding(y) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: + if need_pad: embedding_out = F.pad( embedding_out, pad=(self.context_size - 1, 0) ) else: # During inference time, there is no need to do extra padding # as we only need one output - assert embedding_out.size(-1) == self.context_size + if not torch.jit.is_tracing(): + assert embedding_out.size(-1) == self.context_size embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) embedding_out = F.relu(embedding_out) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index b916addf0..b2d6ed0f2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -52,10 +52,10 @@ class Joiner(nn.Module): Returns: Return a tensor of shape (N, T, s_range, C). """ - - assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - assert encoder_out.shape == decoder_out.shape + if not torch.jit.is_tracing(): + assert encoder_out.ndim == decoder_out.ndim + assert encoder_out.ndim in (2, 4) + assert encoder_out.shape == decoder_out.shape if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index d805dc825..552155d0d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -21,7 +21,7 @@ """ Usage: -(1) Export to torchscript model +(1) Export to torchscript model using torch.jit.script() ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ @@ -36,7 +36,23 @@ load it by `torch.jit.load("cpu_jit.pt")`. Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python are on CPU. You can use `to("cuda")` to move them to a CUDA device. -(2) Export to ONNX format +It will also generates 3 other files: `encoder_jit_script.pt`, +`decoder_jit_script.pt`, and `joiner_jit_script.pt`. + +(2) Export to torchscript model using torch.jit.trace() + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +It will generates 3 files: `encoder_jit_trace.pt`, +`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. + + +(3) Export to ONNX format ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ @@ -53,7 +69,7 @@ Check `onnx_check.py` for how to use them. - joiner.onnx -(3) Export `model.state_dict()` +(4) Export `model.state_dict()` ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ @@ -78,6 +94,8 @@ you can do: --max-duration 600 \ --decoding-method greedy_search \ --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. """ import argparse @@ -87,6 +105,7 @@ from pathlib import Path import sentencepiece as spm import torch import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -154,6 +173,14 @@ def get_parser(): """, ) + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace. + """, + ) + parser.add_argument( "--onnx", type=str2bool, @@ -189,6 +216,128 @@ def get_parser(): return parser +def export_encoder_model_jit_script( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.script() + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + script_model = torch.jit.script(encoder_model) + script_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_script( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.script() + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + script_model = torch.jit.script(decoder_model) + script_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_script( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + """ + script_model = torch.jit.script(joiner_model) + script_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +def export_encoder_model_jit_trace( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + traced_model = torch.jit.trace(encoder_model, (x, x_lens)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + def export_encoder_model_onnx( encoder_model: nn.Module, encoder_filename: str, @@ -262,6 +411,8 @@ def export_decoder_model_onnx( - decoder_out: a torch.float32 tensor of shape (N, 1, C) + Note: The argument need_pad is fixed to False. + Args: decoder_model: The decoder model to be exported. @@ -399,6 +550,7 @@ def main(): model.to("cpu") model.eval() + convert_scaled_to_non_scaled(model, inplace=True) if params.onnx is True: opset_version = 11 @@ -424,6 +576,7 @@ def main(): opset_version=opset_version, ) elif params.jit is True: + logging.info("Using torch.jit.script()") # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not @@ -434,8 +587,29 @@ def main(): filename = params.exp_dir / "cpu_jit.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") + + # Also export encoder/decoder/joiner separately + encoder_filename = params.exp_dir / "encoder_jit_script.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_script.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_script.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + elif params.jit_trace is True: + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) else: - logging.info("Not using torch.jit.script") + logging.info("Not using torchscript") # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = params.exp_dir / "pretrained.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py new file mode 100755 index 000000000..162f8c7db --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, either exported by `torch.jit.trace()` +or by `torch.jit.script()`, and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +or + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless3/jit_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav + +or + +./pruned_transducer_stateless3/jit_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_script.pt \ + --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_script.pt \ + --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_script.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/ncnn_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/ncnn_pretrained.py new file mode 100755 index 000000000..86afe3381 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/ncnn_pretrained.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ncnn models and uses them to decode waves. + +./pruned_transducer_stateless3/jit_pretrained.py \ + --model-dir /path/to/ncnn/model_dir + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav + +We assume there exist following files in the given `model_dir`: + + - encoder_jit_trace.ncnn.param + - encoder_jit_trace.ncnn.bin + - decoder_jit_trace.ncnn.param + - decoder_jit_trace.ncnn.bin + - joiner_jit_trace.ncnn.param + - joiner_jit_trace.ncnn.bin +""" + +import argparse +import logging +from pathlib import Path +from typing import List + +import ncnn +import torch +import torchaudio + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-dir", + type=str, + required=True, + help="Path to the ncnn models directory. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model_dir = Path(args.model_dir) + encoder_param = model_dir / "encoder_jit_trace.ncnn.param" + encoder_bin = model_dir / "encoder_jit_trace.ncnn.bin" + + decoder_param = model_dir / "decoder_jit_trace.ncnn.param" + decoder_bin = model_dir / "decoder_jit_trace.ncnn.bin" + + joiner_param = model_dir / "joiner_jit_trace.ncnn.param" + joiner_bin = model_dir / "joiner_jit_trace.ncnn.bin" + + assert encoder_param.is_file() + assert encoder_bin.is_file() + + assert decoder_param.is_file() + assert decoder_bin.is_file() + + assert joiner_param.is_file() + assert joiner_bin.is_file() + + encoder = ncnn.Net() + decoder = ncnn.Net() + joiner = ncnn.Net() + + # encoder.load_param(str(encoder_param)) # not working yet + # decoder.load_param(str(decoder_param)) + joiner.load_param(str(joiner_param)) + + encoder.clear() + decoder.clear() + joiner.clear() + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py new file mode 100755 index 000000000..ebfae9d5f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +Usage of this script: + +./pruned_transducer_stateless3/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import numpy as np +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: current_encoder_out.numpy(), + joiner_input_nodes[1].name: decoder_out, + }, + )[0] + logits = torch.from_numpy(logits) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=session_opts, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 8b0389bc9..c15d65ded 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -15,7 +15,16 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Usage: +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: (1) greedy search ./pruned_transducer_stateless3/pretrained.py \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py new file mode 100644 index 000000000..c810e36e6 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -0,0 +1,189 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file provides functions to convert `ScaledLinear`, `ScaledConv1d`, +and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`, +and `nn.Conv2d`. + +The scaled version are required only in the training time. It simplifies our +life by converting them their non-scaled version during inference time. +""" + +import copy +import re + +import torch +import torch.nn as nn +from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear + + +def _get_weight(self: torch.nn.Linear): + return self.weight + + +def _get_bias(self: torch.nn.Linear): + return self.bias + + +def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: + """Convert an instance of ScaledLinear to nn.Linear. + + Args: + scaled_linear: + The layer to be converted. + Returns: + Return a linear layer. It satisfies: + + scaled_linear(x) == linear(x) + + for any given input tensor `x`. + """ + assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear) + + # if not hasattr(torch.nn.Linear, "get_weight"): + # torch.nn.Linear.get_weight = _get_weight + # torch.nn.Linear.get_bias = _get_bias + + weight = scaled_linear.get_weight() + bias = scaled_linear.get_bias() + has_bias = bias is not None + + linear = torch.nn.Linear( + in_features=scaled_linear.in_features, + out_features=scaled_linear.out_features, + bias=True, # otherwise, it throws errors when converting to PNNX format. + device=weight.device, + ) + linear.weight.data.copy_(weight) + + if has_bias: + linear.bias.data.copy_(bias) + else: + linear.bias.data.zero_() + + return linear + + +def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d: + """Convert an instance of ScaledConv1d to nn.Conv1d. + + Args: + scaled_conv1d: + The layer to be converted. + Returns: + Return an instance of nn.Conv1d that has the same `forward()` behavior + of the given `scaled_conv1d`. + """ + assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d) + + weight = scaled_conv1d.get_weight() + bias = scaled_conv1d.get_bias() + has_bias = bias is not None + + conv1d = nn.Conv1d( + in_channels=scaled_conv1d.in_channels, + out_channels=scaled_conv1d.out_channels, + kernel_size=scaled_conv1d.kernel_size, + stride=scaled_conv1d.stride, + padding=scaled_conv1d.padding, + dilation=scaled_conv1d.dilation, + groups=scaled_conv1d.groups, + bias=scaled_conv1d.bias is not None, + padding_mode=scaled_conv1d.padding_mode, + ) + + conv1d.weight.data.copy_(weight) + if has_bias: + conv1d.bias.data.copy_(bias) + + return conv1d + + +def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d: + """Convert an instance of ScaledConv2d to nn.Conv2d. + + Args: + scaled_conv2d: + The layer to be converted. + Returns: + Return an instance of nn.Conv2d that has the same `forward()` behavior + of the given `scaled_conv2d`. + """ + assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d) + + weight = scaled_conv2d.get_weight() + bias = scaled_conv2d.get_bias() + has_bias = bias is not None + + conv2d = nn.Conv2d( + in_channels=scaled_conv2d.in_channels, + out_channels=scaled_conv2d.out_channels, + kernel_size=scaled_conv2d.kernel_size, + stride=scaled_conv2d.stride, + padding=scaled_conv2d.padding, + dilation=scaled_conv2d.dilation, + groups=scaled_conv2d.groups, + bias=scaled_conv2d.bias is not None, + padding_mode=scaled_conv2d.padding_mode, + ) + + conv2d.weight.data.copy_(weight) + if has_bias: + conv2d.bias.data.copy_(bias) + + return conv2d + + +def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): + """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d` + in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`, + and `nn.Conv2d`. + + Args: + model: + The model to be converted. + inplace: + If True, the input model is modified inplace. + If False, the input model is copied and we modify the copied version. + Return: + Return a model without scaled layers. + """ + if not inplace: + model = copy.deepcopy(model) + + excluded_patterns = r"self_attn\.(in|out)_proj" + p = re.compile(excluded_patterns) + + d = {} + for name, m in model.named_modules(): + if isinstance(m, ScaledLinear): + if p.search(name) is not None: + continue + d[name] = scaled_linear_to_linear(m) + elif isinstance(m, ScaledConv1d): + d[name] = scaled_conv1d_to_conv1d(m) + elif isinstance(m, ScaledConv2d): + d[name] = scaled_conv2d_to_conv2d(m) + + for k, v in d.items(): + if "." in k: + parent, child = k.rsplit(".", maxsplit=1) + setattr(model.get_submodule(parent), child, v) + else: + setattr(model, k, v) + + return model diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py new file mode 100644 index 000000000..34a9c27f7 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless3/test_scaling_converter.py +""" + +import copy + +import torch +from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear +from scaling_converter import ( + convert_scaled_to_non_scaled, + scaled_conv1d_to_conv1d, + scaled_conv2d_to_conv2d, + scaled_linear_to_linear, +) +from train import get_params, get_transducer_model + + +def get_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.unk_id = 2 + + params.dynamic_chunk_training = False + params.short_chunk_size = 25 + params.num_left_chunks = 4 + params.causal_convolution = False + + model = get_transducer_model(params, enable_giga=False) + return model + + +def test_scaled_linear_to_linear(): + N = 5 + in_features = 10 + out_features = 20 + for bias in [True, False]: + scaled_linear = ScaledLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + ) + linear = scaled_linear_to_linear(scaled_linear) + x = torch.rand(N, in_features) + + y1 = scaled_linear(x) + y2 = linear(x) + assert torch.allclose(y1, y2) + + jit_scaled_linear = torch.jit.script(scaled_linear) + jit_linear = torch.jit.script(linear) + + y3 = jit_scaled_linear(x) + y4 = jit_linear(x) + + assert torch.allclose(y3, y4) + assert torch.allclose(y1, y4) + + +def test_scaled_conv1d_to_conv1d(): + in_channels = 3 + for bias in [True, False]: + scaled_conv1d = ScaledConv1d( + in_channels, + 6, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + conv1d = scaled_conv1d_to_conv1d(scaled_conv1d) + + x = torch.rand(20, in_channels, 10) + y1 = scaled_conv1d(x) + y2 = conv1d(x) + assert torch.allclose(y1, y2) + + jit_scaled_conv1d = torch.jit.script(scaled_conv1d) + jit_conv1d = torch.jit.script(conv1d) + + y3 = jit_scaled_conv1d(x) + y4 = jit_conv1d(x) + + assert torch.allclose(y3, y4) + assert torch.allclose(y1, y4) + + +def test_scaled_conv2d_to_conv2d(): + in_channels = 1 + for bias in [True, False]: + scaled_conv2d = ScaledConv2d( + in_channels=in_channels, + out_channels=3, + kernel_size=3, + padding=1, + bias=bias, + ) + + conv2d = scaled_conv2d_to_conv2d(scaled_conv2d) + + x = torch.rand(20, in_channels, 10, 20) + y1 = scaled_conv2d(x) + y2 = conv2d(x) + assert torch.allclose(y1, y2) + + jit_scaled_conv2d = torch.jit.script(scaled_conv2d) + jit_conv2d = torch.jit.script(conv2d) + + y3 = jit_scaled_conv2d(x) + y4 = jit_conv2d(x) + + assert torch.allclose(y3, y4) + assert torch.allclose(y1, y4) + + +def test_convert_scaled_to_non_scaled(): + for inplace in [False, True]: + model = get_model() + model.eval() + + orig_model = copy.deepcopy(model) + + converted_model = convert_scaled_to_non_scaled(model, inplace=inplace) + + model = orig_model + + # test encoder + N = 2 + T = 100 + vocab_size = model.decoder.vocab_size + + x = torch.randn(N, T, 80, dtype=torch.float32) + x_lens = torch.full((N,), x.size(1)) + + e1, e1_lens = model.encoder(x, x_lens) + e2, e2_lens = converted_model.encoder(x, x_lens) + + assert torch.all(torch.eq(e1_lens, e2_lens)) + assert torch.allclose(e1, e2), (e1 - e2).abs().max() + + # test decoder + U = 50 + y = torch.randint(low=1, high=vocab_size - 1, size=(N, U)) + + d1 = model.decoder(y) + d2 = model.decoder(y) + + assert torch.allclose(d1, d2) + + # test simple projection + lm1 = model.simple_lm_proj(d1) + am1 = model.simple_am_proj(e1) + + lm2 = converted_model.simple_lm_proj(d2) + am2 = converted_model.simple_am_proj(e2) + + assert torch.allclose(lm1, lm2) + assert torch.allclose(am1, am2) + + # test joiner + e = torch.rand(2, 3, 4, 512) + d = torch.rand(2, 3, 4, 512) + + j1 = model.joiner(e, d) + j2 = converted_model.joiner(e, d) + assert torch.allclose(j1, j2) + + +@torch.no_grad() +def main(): + test_scaled_linear_to_linear() + test_scaled_conv1d_to_conv1d() + test_scaled_conv2d_to_conv2d() + test_convert_scaled_to_non_scaled() + + +if __name__ == "__main__": + torch.manual_seed(20220730) + main()