diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f2c63a3b8..fce14c460 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -103,11 +103,26 @@ jobs:
         cd egs/librispeech/ASR/conformer_ctc
         pytest -v -s
 
+        cd ../pruned_transducer_stateless
+        pytest -v -s
+
+        cd ../pruned_transducer_stateless2
+        pytest -v -s
+
+        cd ../pruned_transducer_stateless3
+        pytest -v -s
+
+        cd ../pruned_transducer_stateless4
+        pytest -v -s
+
+        cd ../transducer_stateless
+        pytest -v -s
+
         if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
           cd ../transducer
           pytest -v -s
 
-          cd ../transducer_stateless
+          cd ../transducer_stateless2
           pytest -v -s
 
           cd ../transducer_lstm
@@ -128,11 +143,26 @@ jobs:
         cd egs/librispeech/ASR/conformer_ctc
         pytest -v -s
 
+        cd ../pruned_transducer_stateless
+        pytest -v -s
+
+        cd ../pruned_transducer_stateless2
+        pytest -v -s
+
+        cd ../pruned_transducer_stateless3
+        pytest -v -s
+
+        cd ../pruned_transducer_stateless4
+        pytest -v -s
+
+        cd ../transducer_stateless
+        pytest -v -s
+
         if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
           cd ../transducer
           pytest -v -s
 
-          cd ../transducer_stateless
+          cd ../transducer_stateless2
           pytest -v -s
 
           cd ../transducer_lstm
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
index 7d2a07817..a4210831c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
@@ -116,8 +116,6 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
 
-    assert args.jit is False, "Support torchscript will be added later"
-
     params = get_params()
     params.update(vars(args))
 
@@ -159,6 +157,11 @@ def main():
     model.eval()
 
     if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/test_decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/test_decoder.py
index 937d55c2a..36712018d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/test_decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/test_decoder.py
@@ -29,6 +29,7 @@ from decoder import Decoder
 def test_decoder():
     vocab_size = 3
     blank_id = 0
+    unk_id = 2
     embedding_dim = 128
     context_size = 4
 
@@ -36,6 +37,7 @@ def test_decoder():
         vocab_size=vocab_size,
         embedding_dim=embedding_dim,
         blank_id=blank_id,
+        unk_id=unk_id,
         context_size=context_size,
     )
     N = 100
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
new file mode 100755
index 000000000..5c49025bd
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless/test_model.py
+"""
+
+import torch
+from train import get_params, get_transducer_model
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
+
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    torch.jit.script(model)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 257936b59..840d847cb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -112,10 +112,13 @@ class Conformer(EncoderInterface):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            # Caution: We assume the subsampling factor is 4!
-            lengths = ((x_lens - 1) // 2 - 1) // 2
+        # Caution: We assume the subsampling factor is 4!
+
+        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1
+
         assert x.size(0) == lengths.max().item()
 
         mask = make_pad_mask(lengths)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
index 6b3a7a9ff..cff9c7377 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
@@ -131,8 +131,6 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
 
-    assert args.jit is False, "Support torchscript will be added later"
-
     params = get_params()
     params.update(vars(args))
 
@@ -191,6 +189,11 @@ def main():
     model.eval()
 
     if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index f89d2963e..5ee4bab98 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -212,7 +212,10 @@ class ScaledLinear(nn.Linear):
         return self.weight * self.weight_scale.exp()
 
     def get_bias(self):
-        return None if self.bias is None else self.bias * self.bias_scale.exp()
+        if self.bias is None or self.bias_scale is None:
+            return None
+
+        return self.bias * self.bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
         return torch.nn.functional.linear(
@@ -255,7 +258,11 @@ class ScaledConv1d(nn.Conv1d):
         return self.weight * self.weight_scale.exp()
 
     def get_bias(self):
-        return None if self.bias is None else self.bias * self.bias_scale.exp()
+        bias = self.bias
+        bias_scale = self.bias_scale
+        if bias is None or bias_scale is None:
+            return None
+        return bias * bias_scale.exp()
 
     def forward(self, input: Tensor) -> Tensor:
         F = torch.nn.functional
@@ -269,7 +276,7 @@ class ScaledConv1d(nn.Conv1d):
                 self.get_weight(),
                 self.get_bias(),
                 self.stride,
-                _single(0),
+                (0,),
                 self.dilation,
                 self.groups,
             )
@@ -319,7 +326,12 @@ class ScaledConv2d(nn.Conv2d):
         return self.weight * self.weight_scale.exp()
 
     def get_bias(self):
-        return None if self.bias is None else self.bias * self.bias_scale.exp()
+        # see https://github.com/pytorch/pytorch/issues/24135
+        bias = self.bias
+        bias_scale = self.bias_scale
+        if bias is None or bias_scale is None:
+            return None
+        return bias * bias_scale.exp()
 
     def _conv_forward(self, input, weight):
         F = torch.nn.functional
@@ -333,7 +345,7 @@ class ScaledConv2d(nn.Conv2d):
                 weight,
                 self.get_bias(),
                 self.stride,
-                _pair(0),
+                (0, 0),
                 self.dilation,
                 self.groups,
             )
@@ -398,6 +410,9 @@ class ActivationBalancer(torch.nn.Module):
         self.max_abs = max_abs
 
     def forward(self, x: Tensor) -> Tensor:
+        if torch.jit.is_scripting():
+            return x
+
         return ActivationBalancerFunction.apply(
             x,
             self.channel_dim,
@@ -444,6 +459,8 @@ class DoubleSwish(torch.nn.Module):
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
+        if torch.jit.is_scripting():
+            return x * torch.sigmoid(x - 1.0)
         return DoubleSwishFunction.apply(x)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py
new file mode 100755
index 000000000..9d5c6376d
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless2/test_model.py
+"""
+
+import torch
+from train import get_params, get_transducer_model
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
+
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    torch.jit.script(model)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index 0cdb0b957..e674fb360 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -132,8 +132,6 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
 
-    assert args.jit is False, "Support torchscript will be added later"
-
     params = get_params()
     params.update(vars(args))
 
@@ -192,6 +190,11 @@ def main():
     model.eval()
 
     if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py
new file mode 100755
index 000000000..9a060c5fb
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless3/test_model.py
+"""
+
+import torch
+from train import get_params, get_transducer_model
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
+
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    torch.jit.script(model)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
new file mode 100755
index 000000000..e9dfe6d5e
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless3/test_scaling.py
+"""
+
+import torch
+from scaling import ActivationBalancer, ScaledConv1d, ScaledConv2d
+
+
+def test_scaled_conv1d():
+    for bias in [True, False]:
+        conv1d = ScaledConv1d(
+            3,
+            6,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        torch.jit.script(conv1d)
+
+
+def test_scaled_conv2d():
+    for bias in [True, False]:
+        conv2d = ScaledConv2d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=3,
+            padding=1,
+            bias=bias,
+        )
+        torch.jit.script(conv2d)
+
+
+def test_activation_balancer():
+    act = ActivationBalancer(
+        channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+    )
+    torch.jit.script(act)
+
+
+def main():
+    test_scaled_conv1d()
+    test_scaled_conv2d()
+    test_activation_balancer()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py
new file mode 100755
index 000000000..b1832d0ec
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless4/test_model.py
+"""
+
+import torch
+from train import get_params, get_transducer_model
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.unk_id = 2
+
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    torch.jit.script(model)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 860a84bb1..3dc992dd2 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -94,7 +94,7 @@ class LstmEncoder(EncoderInterface):
         )
 
         if False:
-            # It is commented out as DPP complains that not all parameters are
+            # It is commented out as DDP complains that not all parameters are
             # used. Need more checks later for the reason.
             #
             # Caution: We assume the dataloader returns utterances with
@@ -107,7 +107,7 @@ class LstmEncoder(EncoderInterface):
             )
 
             packed_rnn_out, _ = self.rnn(packed_x)
-            rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True)
+            rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
         else:
             rnn_out, _ = self.rnn(x)
 
diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py
index 31843b60e..e37558a98 100644
--- a/egs/librispeech/ASR/transducer_lstm/model.py
+++ b/egs/librispeech/ASR/transducer_lstm/model.py
@@ -97,8 +97,7 @@ class Transducer(nn.Module):
         y_lens = row_splits[1:] - row_splits[:-1]
 
         blank_id = self.decoder.blank_id
-        sos_id = self.decoder.sos_id
-        sos_y = add_sos(y, sos_id=sos_id)
+        sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
         sos_y_padded = sos_y_padded.to(torch.int64)
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 488c82386..51f13b73f 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -109,10 +109,12 @@ class Conformer(Transformer):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            # Caution: We assume the subsampling factor is 4!
-            lengths = ((x_lens - 1) // 2 - 1) // 2
+        # Caution: We assume the subsampling factor is 4!
+
+        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1
 
         assert x.size(0) == lengths.max().item()
 
         mask = make_pad_mask(lengths)
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 5687260df..8bd0bdea1 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -183,8 +183,6 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
 
-    assert args.jit is False, "Support torchscript will be added later"
-
     params = get_params()
     params.update(vars(args))
 
@@ -226,6 +224,11 @@ def main():
     model.eval()
 
     if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index b0ba7fd83..93cccbd8c 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List
+
 import torch
 import torch.nn as nn
 
@@ -55,8 +57,8 @@ class Joiner(nn.Module):
 
         N = encoder_out.size(0)
 
-        encoder_out_len = encoder_out_len.tolist()
-        decoder_out_len = decoder_out_len.tolist()
+        encoder_out_len: List[int] = encoder_out_len.tolist()
+        decoder_out_len: List[int] = decoder_out_len.tolist()
 
         encoder_out_list = [
             encoder_out[i, : encoder_out_len[i], :] for i in range(N)
diff --git a/egs/librispeech/ASR/transducer_stateless/test_model.py b/egs/librispeech/ASR/transducer_stateless/test_model.py
new file mode 100755
index 000000000..9e500f477
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_stateless/test_model.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./transducer_stateless/test_model.py
+"""
+
+import torch
+from train import get_params, get_transducer_model
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    torch.jit.script(model)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index 7a68f69ff..57c1a6094 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -115,8 +115,6 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
 
-    assert args.jit is False, "Support torchscript will be added later"
-
     params = get_params()
     params.update(vars(args))
 
@@ -158,6 +156,11 @@ def main():
     model.eval()
 
     if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
diff --git a/egs/librispeech/ASR/transducer_stateless2/joiner.py b/egs/librispeech/ASR/transducer_stateless2/joiner.py
index 765f0be8b..4882f9971 100644
--- a/egs/librispeech/ASR/transducer_stateless2/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless2/joiner.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
@@ -30,7 +32,8 @@ class Joiner(nn.Module):
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
-        *unused,
+        unused_encoder_out_len: Optional[torch.Tensor] = None,
+        unused_decoder_out_len: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -38,10 +41,12 @@ class Joiner(nn.Module):
             Output from the encoder. Its shape is (N, T, self.input_dim).
           decoder_out:
             Output from the decoder. Its shape is (N, U, self.input_dim).
-          unused:
+          unused_encoder_out_len:
             This is a placeholder so that we can reuse
             transducer_stateless/beam_search.py in this folder as that
             script assumes the joiner networks accepts 4 inputs.
+          unused_decoder_out_len:
+            Just a placeholder.
         Returns:
           Return a tensor of shape (N, T, U, self.output_dim).
         """
diff --git a/egs/librispeech/ASR/transducer_stateless2/test_model.py b/egs/librispeech/ASR/transducer_stateless2/test_model.py
new file mode 100755
index 000000000..bd2230a45
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_stateless2/test_model.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./transducer_stateless2/test_model.py
+"""
+
+import torch
+from train import get_params, get_transducer_model
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    torch.jit.script(model)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index 7d14d011d..b6b69d932 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -184,8 +184,6 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
 
-    assert args.jit is False, "Support torchscript will be added later"
-
     params = get_params()
     params.update(vars(args))
 
@@ -229,6 +227,11 @@ def main():
     model.eval()
 
     if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
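
Two small illustrative sketches for reviewers; neither is part of the patch itself.

First, the patch replaces lengths = ((x_lens - 1) // 2 - 1) // 2 with a bit-shift form to silence the floor-division warning while staying torch-scriptable. For Python integers, >> 1 and // 2 both round toward negative infinity, so the two forms agree; a quick check:

    # Sanity check: the shift-based subsampled-length formula matches the
    # original floor-division formula (Python ints; n >> 1 floors like n // 2).
    for n in range(1000):
        assert (((n - 1) >> 1) - 1) >> 1 == ((n - 1) // 2 - 1) // 2

Second, a minimal sketch of how the exported cpu_jit.pt might be consumed from Python. The checkpoint path, feature dimension, and tensor shapes below are assumptions for illustration. Because forward() is marked with torch.jit.ignore() before scripting, a caller invokes the submodules directly, as the C++ deployment code would:

    import torch

    # Assumed: cpu_jit.pt was written by one of the export.py scripts above
    # and the encoder expects 80-dim fbank features.
    model = torch.jit.load("cpu_jit.pt")
    model.eval()

    features = torch.rand(1, 100, 80)   # (N, T, C); dummy features
    feature_lens = torch.tensor([100])  # number of valid frames per utterance
    encoder_out, encoder_out_lens = model.encoder(features, feature_lens)
    print(encoder_out.shape, encoder_out_lens)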