Various fixes to support torch script. (#371)

* Various fixes to support torch script.

* Add tests to ensure that the model is torch scriptable.

* Update tests.
Fangjun Kuang 2022-05-16 21:46:59 +08:00 committed by GitHub
parent 5aafbb970e
commit f6ce135608
22 changed files with 480 additions and 35 deletions

View File

@@ -103,11 +103,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
cd ../pruned_transducer_stateless
pytest -v -s
cd ../pruned_transducer_stateless2
pytest -v -s
cd ../pruned_transducer_stateless3
pytest -v -s
cd ../pruned_transducer_stateless4
pytest -v -s
cd ../transducer_stateless
pytest -v -s
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer
pytest -v -s
cd ../transducer_stateless
cd ../transducer_stateless2
pytest -v -s
cd ../transducer_lstm
@@ -128,11 +143,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
cd ../pruned_transducer_stateless
pytest -v -s
cd ../pruned_transducer_stateless2
pytest -v -s
cd ../pruned_transducer_stateless3
pytest -v -s
cd ../pruned_transducer_stateless4
pytest -v -s
cd ../transducer_stateless
pytest -v -s
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer
pytest -v -s
cd ../transducer_stateless
cd ../transducer_stateless2
pytest -v -s
cd ../transducer_lstm

View File

@@ -116,8 +116,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@@ -159,6 +157,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
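
A minimal sketch of the torch.jit.ignore pattern used here (and repeated in the other export scripts below): forward() takes a k2 ragged tensor, which TorchScript cannot compile, so it is replaced with an ignored version before torch.jit.script() is called. ToyTransducer and run_encoder are made-up names for illustration only:

import torch
import torch.nn as nn


class ToyTransducer(nn.Module):
    """A made-up stand-in for the real transducer model."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)

    def forward(self, x, y_ragged):
        # In the real model, one argument is a k2.RaggedTensor, which
        # TorchScript cannot handle; this method is ignored below.
        raise NotImplementedError

    @torch.jit.export
    def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
        # Uses only plain tensors, so it stays scriptable; this is the
        # kind of method that would be called from C++.
        return self.encoder(x)


model = ToyTransducer()
# Mark forward() as ignored so torch.jit.script() skips it instead of
# failing on the unsupported argument type.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
scripted = torch.jit.script(model)
scripted.save("cpu_jit.pt")  # mirrors what export.py does after scripting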

View File

@@ -29,6 +29,7 @@ from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
unk_id = 2
embedding_dim = 128
context_size = 4
@@ -36,6 +37,7 @@ def test_decoder():
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
unk_id=unk_id,
context_size=context_size,
)
N = 100

View File

@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@@ -112,10 +112,13 @@ class Conformer(EncoderInterface):
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
# Caution: We assume the subsampling factor is 4!
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issues a warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
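
The rewritten length computation keeps the subsampling-factor-4 formula but avoids the floor-division warning and does not need rounding_mode, which requires torch >= 1.8: for non-negative integer lengths, x >> 1 equals x // 2. A quick check with made-up frame counts:

import torch

x_lens = torch.tensor([9, 16, 23, 100])  # made-up numbers of frames

old = ((x_lens - 1) // 2 - 1) // 2    # may emit a __floordiv__ deprecation warning
new = (((x_lens - 1) >> 1) - 1) >> 1  # warning-free, same values

assert torch.equal(old, new)

# On torch >= 1.8 an equivalent form would be:
#   lengths = torch.div(x_lens - 1, 2, rounding_mode="floor")
#   lengths = torch.div(lengths - 1, 2, rounding_mode="floor")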

View File

@@ -131,8 +131,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@@ -191,6 +189,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@@ -212,7 +212,10 @@ class ScaledLinear(nn.Linear):
return self.weight * self.weight_scale.exp()
def get_bias(self):
return None if self.bias is None else self.bias * self.bias_scale.exp()
if self.bias is None or self.bias_scale is None:
return None
return self.bias * self.bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
return torch.nn.functional.linear(
@@ -255,7 +258,11 @@ class ScaledConv1d(nn.Conv1d):
return self.weight * self.weight_scale.exp()
def get_bias(self):
return None if self.bias is None else self.bias * self.bias_scale.exp()
bias = self.bias
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
def forward(self, input: Tensor) -> Tensor:
F = torch.nn.functional
@@ -269,7 +276,7 @@ class ScaledConv1d(nn.Conv1d):
self.get_weight(),
self.get_bias(),
self.stride,
_single(0),
(0,),
self.dilation,
self.groups,
)
@@ -319,7 +326,12 @@ class ScaledConv2d(nn.Conv2d):
return self.weight * self.weight_scale.exp()
def get_bias(self):
return None if self.bias is None else self.bias * self.bias_scale.exp()
# see https://github.com/pytorch/pytorch/issues/24135
bias = self.bias
bias_scale = self.bias_scale
if bias is None or bias_scale is None:
return None
return bias * bias_scale.exp()
def _conv_forward(self, input, weight):
F = torch.nn.functional
@@ -333,7 +345,7 @@ class ScaledConv2d(nn.Conv2d):
weight,
self.get_bias(),
self.stride,
_pair(0),
(0, 0),
self.dilation,
self.groups,
)
@@ -398,6 +410,9 @@ class ActivationBalancer(torch.nn.Module):
self.max_abs = max_abs
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting():
return x
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
@@ -444,6 +459,8 @@ class DoubleSwish(torch.nn.Module):
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
"""
if torch.jit.is_scripting():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x)
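
The scaling.py changes are three independent TorchScript fixes: the one-line `None if self.bias is None else ...` form is expanded because TorchScript cannot refine the Optional type of a module attribute directly (see pytorch/pytorch#24135), so the attribute is bound to a local and checked explicitly; the private _single(0) / _pair(0) helpers are replaced with the literal tuples (0,) and (0, 0); and torch.jit.is_scripting() guards let the scripted path use plain tensor formulas instead of the custom autograd Functions. A compact sketch of the first and last idioms, using made-up ToyBias and ToyDoubleSwish classes:

from typing import Optional

import torch
import torch.nn as nn


class ToyBias(nn.Module):
    """Made-up module whose bias and bias_scale may be None, as in ScaledConv1d."""

    def __init__(self, dim: int, bias: bool = True):
        super().__init__()
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim))
            self.bias_scale = nn.Parameter(torch.zeros(()))
        else:
            self.register_parameter("bias", None)
            self.register_parameter("bias_scale", None)

    def forward(self) -> Optional[torch.Tensor]:
        # The one-line form
        #   return None if self.bias is None else self.bias * self.bias_scale.exp()
        # did not script at the time of this commit; binding the attributes
        # to locals and branching explicitly lets TorchScript narrow
        # Optional[Tensor] to Tensor inside the branch.
        bias = self.bias
        bias_scale = self.bias_scale
        if bias is None or bias_scale is None:
            return None
        return bias * bias_scale.exp()


class ToyDoubleSwish(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            # Scripted path: plain formula only.
            return x * torch.sigmoid(x - 1.0)
        # Eager path: the real module calls a custom autograd Function
        # (DoubleSwishFunction) here, which torch.jit cannot compile.
        return x * torch.sigmoid(x - 1.0)


torch.jit.script(ToyBias(4, bias=True))
torch.jit.script(ToyBias(4, bias=False))
torch.jit.script(ToyDoubleSwish())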

View File

@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless2/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@@ -132,8 +132,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@@ -192,6 +190,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless3/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless3/test_scaling.py
"""
import torch
from scaling import ActivationBalancer, ScaledConv1d, ScaledConv2d
def test_scaled_conv1d():
for bias in [True, False]:
conv1d = ScaledConv1d(
3,
6,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
torch.jit.script(conv1d)
def test_scaled_conv2d():
for bias in [True, False]:
conv2d = ScaledConv2d(
in_channels=1,
out_channels=3,
kernel_size=3,
padding=1,
bias=bias,
)
torch.jit.script(conv2d)
def test_activation_balancer():
act = ActivationBalancer(
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
)
torch.jit.script(act)
def main():
test_scaled_conv1d()
test_scaled_conv2d()
test_activation_balancer()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless4/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@@ -94,7 +94,7 @@ class LstmEncoder(EncoderInterface):
)
if False:
# It is commented out as DPP complains that not all parameters are
# It is commented out as DDP complains that not all parameters are
# used. Need more checks later for the reason.
#
# Caution: We assume the dataloader returns utterances with
@@ -107,7 +107,7 @@ class LstmEncoder(EncoderInterface):
)
packed_rnn_out, _ = self.rnn(packed_x)
rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True)
rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
else:
rnn_out, _ = self.rnn(x)
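
The one-line fix above changes which object is unpacked: pad_packed_sequence must receive the RNN's packed output, not the packed input. A toy reference for the intended pairing (shapes and the LSTM below are made up):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = nn.LSTM(input_size=8, hidden_size=8, batch_first=True)
x = torch.randn(3, 5, 8)          # (N, T, C), made-up batch
x_lens = torch.tensor([5, 3, 2])  # valid frames per utterance, sorted descending

packed_x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=True)
packed_rnn_out, _ = rnn(packed_x)
# Bug being fixed: unpacking packed_x would just return the padded input.
rnn_out, out_lens = pad_packed_sequence(packed_rnn_out, batch_first=True)
assert rnn_out.shape == (3, 5, 8)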

View File

@@ -97,8 +97,7 @@ class Transducer(nn.Module):
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_id = self.decoder.sos_id
sos_y = add_sos(y, sos_id=sos_id)
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)

View File

@@ -109,10 +109,12 @@ class Conformer(Transformer):
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
# Caution: We assume the subsampling factor is 4!
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issues a warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)

View File

@@ -183,8 +183,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@@ -226,6 +224,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
import torch.nn as nn
@@ -55,8 +57,8 @@ class Joiner(nn.Module):
N = encoder_out.size(0)
encoder_out_len = encoder_out_len.tolist()
decoder_out_len = decoder_out_len.tolist()
encoder_out_len: List[int] = encoder_out_len.tolist()
decoder_out_len: List[int] = decoder_out_len.tolist()
encoder_out_list = [
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
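
TorchScript cannot infer the return type of Tensor.tolist() (it depends on the tensor's dtype and dimensionality), which is why the results are now annotated as List[int] and typing.List is imported. A tiny standalone illustration with a made-up function name:

from typing import List

import torch


@torch.jit.script
def lengths_to_list(lengths: torch.Tensor) -> List[int]:
    # Without the List[int] annotation, scripting fails because the
    # return type of tolist() cannot be inferred statically.
    out: List[int] = lengths.tolist()
    return out


print(lengths_to_list(torch.tensor([3, 5, 2])))  # [3, 5, 2]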

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_stateless/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@@ -115,8 +115,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@@ -158,6 +156,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
@@ -30,7 +32,8 @@ class Joiner(nn.Module):
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
*unused,
unused_encoder_out_len: Optional[torch.Tensor] = None,
unused_decoder_out_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
@@ -38,10 +41,12 @@ class Joiner(nn.Module):
Output from the encoder. Its shape is (N, T, self.input_dim).
decoder_out:
Output from the decoder. Its shape is (N, U, self.input_dim).
unused:
unused_encoder_out_len:
This is a placeholder so that we can reuse
transducer_stateless/beam_search.py in this folder as that
script assumes the joiner network accepts 4 inputs.
unused_decoder_out_len:
Just a placeholder.
Returns:
Return a tensor of shape (N, T, U, self.output_dim).
"""

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_stateless2/test_model.py
"""
import torch
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@@ -184,8 +184,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@@ -229,6 +227,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"