From 2d684f11bfb5cc7d75eeb17cee357bece1a51955 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 20 Jun 2022 12:30:49 +0800 Subject: [PATCH] support torch script. --- egs/aishell/ASR/README.md | 1 - egs/aishell/ASR/RESULTS.md | 77 +++- .../pruned_transducer_stateless3/export.py | 277 ++++++++++++++ .../pretrained.py | 337 ++++++++++++++++++ .../ASR/pruned_transducer_stateless3/train.py | 8 +- .../pruned_transducer_stateless2/decoder.py | 3 + .../pruned_transducer_stateless5/conformer.py | 39 +- .../pruned_transducer_stateless5/export.py | 9 +- 8 files changed, 729 insertions(+), 22 deletions(-) create mode 100755 egs/aishell/ASR/pruned_transducer_stateless3/export.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index a5d49b6be..75fc6326e 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -4,7 +4,6 @@ Please refer to for how to run models in this recipe. -# Pruned transducer stateless 3 # Transducers diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index ecc93c21b..5c9490a22 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -1,10 +1,85 @@ ## Results -### Aishell training result(Transducer-stateless) + +### Aishell training result(Stateless Transducer) + +#### Pruned transducer stateless 3 + +See + +[./pruned_transducer_stateless3](./pruned_transducer_stateless3) + +It uses pruned RNN-T. + +| | test | dev | comment | +|------------------------|------|------|---------------------------------------| +| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 | +| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 | + +Training command is: + +```bash +export CUDA_VISIBLE_DEVICES="4,5,6,7" + +./pruned_transducer_stateless3/train.py \ + --exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \ + --world-size 4 \ + --max-duration 200 \ + --datatang-prob 0.5 \ + --start-epoch 1 \ + --num-epochs 30 \ + --use-fp16 1 \ + --num-encoder-layers 12 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --context-size 1 \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --master-port 12356 +``` + +**Caution**: It uses `--context-size=1`. + +The tensorboard log is available at + + +The decoding command is: + +```bash +for epoch in 29; do + for avg in 5; do + for m in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless3/decode.py \ + --exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --max-duration 600 \ + --decoding-method $m \ + --num-encoder-layers 12 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --context-size 1 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + done + done +done +``` + +Pretrained models, training logs, decoding logs, and decoding results +are available at + #### 2022-03-01 [./transducer_stateless_modified-2](./transducer_stateless_modified-2) +It uses [optimized_transducer](https://github.com/csukuangfj/optimized_transducer) +for computing RNN-T loss. + Stateless transducer + modified transducer + using [aidatatang_200zh](http://www.openslr.org/62/) as extra training data. diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py new file mode 100755 index 000000000..307895a76 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --jit 0 \ + --epoch 29 \ + --avg 5 + +It will generate a file exp_dir/pretrained-epoch-29-avg-5.pt + +To use the generated file with `pruned_transducer_stateless3/decode.py`, +you can do:: + + cd /path/to/exp_dir + ln -s pretrained-epoch-29-avg-5.pt epoch-9999.pt + + cd /path/to/egs/aishell/ASR + ./pruned_transducer_stateless3/decode.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --lang-dir data/lang_char +""" + +import argparse +import logging +from pathlib import Path + +import torch +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=29, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default=Path("pruned_transducer_stateless3/exp"), + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default=Path("data/lang_char"), + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = ( + params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" + ) + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = ( + params.exp_dir + / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + ) + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py new file mode 100755 index 000000000..5cda411bc --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + +(1) greedy search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.lexicon import Lexicon + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default=Path("data/lang_char"), + help="The lang dir", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="Maximum number of symbols per frame. " + "Use only when --method is greedy_search", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lens = [f.size(0) for f in features] + feature_lens = torch.tensor(feature_lens, device=device) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) + + num_waves = encoder_out.size(0) + hyp_list = [] + logging.info(f"Using {params.method}") + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_list = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.method}" + ) + hyp_list.append(hyp) + + hyps = [] + for hyp in hyp_list: + hyps.append([lexicon.token_table[i] for i in hyp]) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index 2bc201f05..02efe94fe 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -101,14 +101,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=int, - default=36, + default=12, help="Number of conformer encoder layers..", ) parser.add_argument( "--dim-feedforward", type=int, - default=1024, + default=2048, help="Feedforward dimension of the conformer encoder layer.", ) @@ -122,7 +122,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-dim", type=int, - default=256, + default=512, help="Attention dimension in the conformer encoder layer.", ) @@ -1185,7 +1185,7 @@ def scan_pessimistic_batches_for_oom( graph_compiler=graph_compiler, batch=batch, is_training=True, - warmup=0.0, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index b6d94aaf1..1ddfce034 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -73,6 +73,9 @@ class Decoder(nn.Module): groups=decoder_dim, bias=False, ) + else: + # It is to support torch script + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 6f7231f4b..bf3917df0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -117,10 +117,7 @@ class Conformer(EncoderInterface): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 + lengths = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) @@ -293,8 +290,10 @@ class ConformerEncoder(nn.Module): ) self.num_layers = num_layers + assert len(set(aux_layers)) == len(aux_layers) + assert num_layers - 1 not in aux_layers - self.aux_layers = set(aux_layers + [num_layers - 1]) + self.aux_layers = aux_layers + [num_layers - 1] num_channels = encoder_layer.norm_final.num_channels self.combiner = RandomCombine( @@ -1154,7 +1153,7 @@ class RandomCombine(nn.Module): """ num_inputs = self.num_inputs assert len(inputs) == num_inputs - if not self.training: + if not self.training or torch.jit.is_scripting(): return inputs[-1] # Shape of weights: (*, num_inputs) @@ -1162,8 +1161,22 @@ class RandomCombine(nn.Module): num_frames = inputs[0].numel() // num_channels mod_inputs = [] - for i in range(num_inputs - 1): - mod_inputs.append(self.linear[i](inputs[i])) + + if False: + # It throws the following error for torch 1.6.0 when using + # torch script. + # + # Expected integer literal for index. ModuleList/Sequential + # indexing is only supported with integer literals. Enumeration is + # supported, e.g. 'for index, v in enumerate(self): ...': + # for i in range(num_inputs - 1): + # mod_inputs.append(self.linear[i](inputs[i])) + assert False + else: + for i, linear in enumerate(self.linear): + if i < num_inputs - 1: + mod_inputs.append(linear(inputs[i])) + mod_inputs.append(inputs[num_inputs - 1]) ndim = inputs[0].ndim @@ -1181,11 +1194,13 @@ class RandomCombine(nn.Module): # ans: (num_frames, num_channels, 1) ans = torch.matmul(stacked_inputs, weights) # ans: (*, num_channels) - ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) - if __name__ == "__main__": - # for testing only... - print("Weights = ", weights.reshape(num_frames, num_inputs)) + ans = ans.reshape(inputs[0].shape[:-1] + [num_channels]) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... + # print("Weights = ", weights.reshape(num_frames, num_inputs)) return ans def _get_random_weights( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index f1269a4bd..936508900 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -146,8 +146,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -246,12 +244,15 @@ def main(): ) ) - model.eval() - model.to("cpu") model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt"