mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Various fixes to support torch script. (#371)
* Various fixes to support torch script. * Add tests to ensure that the model is torch scriptable. * Update tests.
This commit is contained in:
parent
5aafbb970e
commit
f6ce135608
36
.github/workflows/test.yml
vendored
36
.github/workflows/test.yml
vendored
@ -103,11 +103,26 @@ jobs:
|
|||||||
cd egs/librispeech/ASR/conformer_ctc
|
cd egs/librispeech/ASR/conformer_ctc
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless2
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless3
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless4
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../transducer_stateless
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
||||||
cd ../transducer
|
cd ../transducer
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
cd ../transducer_stateless
|
cd ../transducer_stateless2
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
cd ../transducer_lstm
|
cd ../transducer_lstm
|
||||||
@ -128,13 +143,28 @@ jobs:
|
|||||||
cd egs/librispeech/ASR/conformer_ctc
|
cd egs/librispeech/ASR/conformer_ctc
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
cd ../pruned_transducer_stateless
|
||||||
cd ../transducer
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless2
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless3
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../pruned_transducer_stateless4
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
cd ../transducer_stateless
|
cd ../transducer_stateless
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
|
|
||||||
|
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
||||||
|
cd ../transducer
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
|
cd ../transducer_stateless2
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
cd ../transducer_lstm
|
cd ../transducer_lstm
|
||||||
pytest -v -s
|
pytest -v -s
|
||||||
fi
|
fi
|
||||||
|
@ -116,8 +116,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -159,6 +157,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -29,6 +29,7 @@ from decoder import Decoder
|
|||||||
def test_decoder():
|
def test_decoder():
|
||||||
vocab_size = 3
|
vocab_size = 3
|
||||||
blank_id = 0
|
blank_id = 0
|
||||||
|
unk_id = 2
|
||||||
embedding_dim = 128
|
embedding_dim = 128
|
||||||
context_size = 4
|
context_size = 4
|
||||||
|
|
||||||
@ -36,6 +37,7 @@ def test_decoder():
|
|||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
blank_id=blank_id,
|
blank_id=blank_id,
|
||||||
|
unk_id=unk_id,
|
||||||
context_size=context_size,
|
context_size=context_size,
|
||||||
)
|
)
|
||||||
N = 100
|
N = 100
|
||||||
|
50
egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
Executable file
50
egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
Executable file
@ -0,0 +1,50 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./pruned_transducer_stateless/test_model.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_model():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.unk_id = 2
|
||||||
|
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
torch.jit.script(model)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_model()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -112,10 +112,13 @@ class Conformer(EncoderInterface):
|
|||||||
x, pos_emb = self.encoder_pos(x)
|
x, pos_emb = self.encoder_pos(x)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
# Caution: We assume the subsampling factor is 4!
|
# Caution: We assume the subsampling factor is 4!
|
||||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
|
||||||
|
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
|
||||||
|
#
|
||||||
|
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
||||||
|
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||||
|
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
mask = make_pad_mask(lengths)
|
mask = make_pad_mask(lengths)
|
||||||
|
|
||||||
|
@ -131,8 +131,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -191,6 +189,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -212,7 +212,10 @@ class ScaledLinear(nn.Linear):
|
|||||||
return self.weight * self.weight_scale.exp()
|
return self.weight * self.weight_scale.exp()
|
||||||
|
|
||||||
def get_bias(self):
|
def get_bias(self):
|
||||||
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
if self.bias is None or self.bias_scale is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.bias * self.bias_scale.exp()
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
return torch.nn.functional.linear(
|
return torch.nn.functional.linear(
|
||||||
@ -255,7 +258,11 @@ class ScaledConv1d(nn.Conv1d):
|
|||||||
return self.weight * self.weight_scale.exp()
|
return self.weight * self.weight_scale.exp()
|
||||||
|
|
||||||
def get_bias(self):
|
def get_bias(self):
|
||||||
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
bias = self.bias
|
||||||
|
bias_scale = self.bias_scale
|
||||||
|
if bias is None or bias_scale is None:
|
||||||
|
return None
|
||||||
|
return bias * bias_scale.exp()
|
||||||
|
|
||||||
def forward(self, input: Tensor) -> Tensor:
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
F = torch.nn.functional
|
F = torch.nn.functional
|
||||||
@ -269,7 +276,7 @@ class ScaledConv1d(nn.Conv1d):
|
|||||||
self.get_weight(),
|
self.get_weight(),
|
||||||
self.get_bias(),
|
self.get_bias(),
|
||||||
self.stride,
|
self.stride,
|
||||||
_single(0),
|
(0,),
|
||||||
self.dilation,
|
self.dilation,
|
||||||
self.groups,
|
self.groups,
|
||||||
)
|
)
|
||||||
@ -319,7 +326,12 @@ class ScaledConv2d(nn.Conv2d):
|
|||||||
return self.weight * self.weight_scale.exp()
|
return self.weight * self.weight_scale.exp()
|
||||||
|
|
||||||
def get_bias(self):
|
def get_bias(self):
|
||||||
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
# see https://github.com/pytorch/pytorch/issues/24135
|
||||||
|
bias = self.bias
|
||||||
|
bias_scale = self.bias_scale
|
||||||
|
if bias is None or bias_scale is None:
|
||||||
|
return None
|
||||||
|
return bias * bias_scale.exp()
|
||||||
|
|
||||||
def _conv_forward(self, input, weight):
|
def _conv_forward(self, input, weight):
|
||||||
F = torch.nn.functional
|
F = torch.nn.functional
|
||||||
@ -333,7 +345,7 @@ class ScaledConv2d(nn.Conv2d):
|
|||||||
weight,
|
weight,
|
||||||
self.get_bias(),
|
self.get_bias(),
|
||||||
self.stride,
|
self.stride,
|
||||||
_pair(0),
|
(0, 0),
|
||||||
self.dilation,
|
self.dilation,
|
||||||
self.groups,
|
self.groups,
|
||||||
)
|
)
|
||||||
@ -398,6 +410,9 @@ class ActivationBalancer(torch.nn.Module):
|
|||||||
self.max_abs = max_abs
|
self.max_abs = max_abs
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
if torch.jit.is_scripting():
|
||||||
|
return x
|
||||||
|
|
||||||
return ActivationBalancerFunction.apply(
|
return ActivationBalancerFunction.apply(
|
||||||
x,
|
x,
|
||||||
self.channel_dim,
|
self.channel_dim,
|
||||||
@ -444,6 +459,8 @@ class DoubleSwish(torch.nn.Module):
|
|||||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||||
that we approximate closely with x * sigmoid(x-1).
|
that we approximate closely with x * sigmoid(x-1).
|
||||||
"""
|
"""
|
||||||
|
if torch.jit.is_scripting():
|
||||||
|
return x * torch.sigmoid(x - 1.0)
|
||||||
return DoubleSwishFunction.apply(x)
|
return DoubleSwishFunction.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
50
egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py
Executable file
50
egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py
Executable file
@ -0,0 +1,50 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./pruned_transducer_stateless2/test_model.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_model():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.unk_id = 2
|
||||||
|
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
torch.jit.script(model)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_model()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -132,8 +132,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -192,6 +190,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
50
egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py
Executable file
50
egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py
Executable file
@ -0,0 +1,50 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./pruned_transducer_stateless3/test_model.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_model():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.unk_id = 2
|
||||||
|
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
torch.jit.script(model)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_model()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
69
egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
Executable file
69
egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
Executable file
@ -0,0 +1,69 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./pruned_transducer_stateless3/test_scaling.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling import ActivationBalancer, ScaledConv1d, ScaledConv2d
|
||||||
|
|
||||||
|
|
||||||
|
def test_scaled_conv1d():
|
||||||
|
for bias in [True, False]:
|
||||||
|
conv1d = ScaledConv1d(
|
||||||
|
3,
|
||||||
|
6,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
torch.jit.script(conv1d)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scaled_conv2d():
|
||||||
|
for bias in [True, False]:
|
||||||
|
conv2d = ScaledConv2d(
|
||||||
|
in_channels=1,
|
||||||
|
out_channels=3,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
torch.jit.script(conv2d)
|
||||||
|
|
||||||
|
|
||||||
|
def test_activation_balancer():
|
||||||
|
act = ActivationBalancer(
|
||||||
|
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||||
|
)
|
||||||
|
torch.jit.script(act)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_scaled_conv1d()
|
||||||
|
test_scaled_conv2d()
|
||||||
|
test_activation_balancer()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
50
egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py
Executable file
50
egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py
Executable file
@ -0,0 +1,50 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./pruned_transducer_stateless4/test_model.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_model():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.unk_id = 2
|
||||||
|
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
torch.jit.script(model)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_model()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -94,7 +94,7 @@ class LstmEncoder(EncoderInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
# It is commented out as DPP complains that not all parameters are
|
# It is commented out as DDP complains that not all parameters are
|
||||||
# used. Need more checks later for the reason.
|
# used. Need more checks later for the reason.
|
||||||
#
|
#
|
||||||
# Caution: We assume the dataloader returns utterances with
|
# Caution: We assume the dataloader returns utterances with
|
||||||
@ -107,7 +107,7 @@ class LstmEncoder(EncoderInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
packed_rnn_out, _ = self.rnn(packed_x)
|
packed_rnn_out, _ = self.rnn(packed_x)
|
||||||
rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True)
|
rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
|
||||||
else:
|
else:
|
||||||
rnn_out, _ = self.rnn(x)
|
rnn_out, _ = self.rnn(x)
|
||||||
|
|
||||||
|
@ -97,8 +97,7 @@ class Transducer(nn.Module):
|
|||||||
y_lens = row_splits[1:] - row_splits[:-1]
|
y_lens = row_splits[1:] - row_splits[:-1]
|
||||||
|
|
||||||
blank_id = self.decoder.blank_id
|
blank_id = self.decoder.blank_id
|
||||||
sos_id = self.decoder.sos_id
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
sos_y = add_sos(y, sos_id=sos_id)
|
|
||||||
|
|
||||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
sos_y_padded = sos_y_padded.to(torch.int64)
|
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||||
|
@ -109,10 +109,12 @@ class Conformer(Transformer):
|
|||||||
x, pos_emb = self.encoder_pos(x)
|
x, pos_emb = self.encoder_pos(x)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
# Caution: We assume the subsampling factor is 4!
|
# Caution: We assume the subsampling factor is 4!
|
||||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
|
||||||
|
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
|
||||||
|
#
|
||||||
|
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
||||||
|
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||||
|
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
mask = make_pad_mask(lengths)
|
mask = make_pad_mask(lengths)
|
||||||
|
@ -183,8 +183,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -226,6 +224,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@ -55,8 +57,8 @@ class Joiner(nn.Module):
|
|||||||
|
|
||||||
N = encoder_out.size(0)
|
N = encoder_out.size(0)
|
||||||
|
|
||||||
encoder_out_len = encoder_out_len.tolist()
|
encoder_out_len: List[int] = encoder_out_len.tolist()
|
||||||
decoder_out_len = decoder_out_len.tolist()
|
decoder_out_len: List[int] = decoder_out_len.tolist()
|
||||||
|
|
||||||
encoder_out_list = [
|
encoder_out_list = [
|
||||||
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
|
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
|
||||||
|
49
egs/librispeech/ASR/transducer_stateless/test_model.py
Executable file
49
egs/librispeech/ASR/transducer_stateless/test_model.py
Executable file
@ -0,0 +1,49 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./transducer_stateless/test_model.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_model():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
torch.jit.script(model)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_model()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -115,8 +115,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -158,6 +156,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@ -30,7 +32,8 @@ class Joiner(nn.Module):
|
|||||||
self,
|
self,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
decoder_out: torch.Tensor,
|
decoder_out: torch.Tensor,
|
||||||
*unused,
|
unused_encoder_out_len: Optional[torch.Tensor] = None,
|
||||||
|
unused_decoder_out_len: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -38,10 +41,12 @@ class Joiner(nn.Module):
|
|||||||
Output from the encoder. Its shape is (N, T, self.input_dim).
|
Output from the encoder. Its shape is (N, T, self.input_dim).
|
||||||
decoder_out:
|
decoder_out:
|
||||||
Output from the decoder. Its shape is (N, U, self.input_dim).
|
Output from the decoder. Its shape is (N, U, self.input_dim).
|
||||||
unused:
|
unused_encoder_out_len:
|
||||||
This is a placeholder so that we can reuse
|
This is a placeholder so that we can reuse
|
||||||
transducer_stateless/beam_search.py in this folder as that
|
transducer_stateless/beam_search.py in this folder as that
|
||||||
script assumes the joiner networks accepts 4 inputs.
|
script assumes the joiner networks accepts 4 inputs.
|
||||||
|
unused_decoder_out_len:
|
||||||
|
Just a placeholder.
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, U, self.output_dim).
|
Return a tensor of shape (N, T, U, self.output_dim).
|
||||||
"""
|
"""
|
||||||
|
49
egs/librispeech/ASR/transducer_stateless2/test_model.py
Executable file
49
egs/librispeech/ASR/transducer_stateless2/test_model.py
Executable file
@ -0,0 +1,49 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./transducer_stateless2/test_model.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_model():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
torch.jit.script(model)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_model()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -184,8 +184,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -229,6 +227,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user