mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Various fixes to support torch script. (#371)
* Various fixes to support torch script. * Add tests to ensure that the model is torch scriptable. * Update tests.
This commit is contained in:
parent
5aafbb970e
commit
f6ce135608
34
.github/workflows/test.yml
vendored
34
.github/workflows/test.yml
vendored
@ -103,11 +103,26 @@ jobs:
|
||||
cd egs/librispeech/ASR/conformer_ctc
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless2
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless3
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless4
|
||||
pytest -v -s
|
||||
|
||||
cd ../transducer_stateless
|
||||
pytest -v -s
|
||||
|
||||
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
||||
cd ../transducer
|
||||
pytest -v -s
|
||||
|
||||
cd ../transducer_stateless
|
||||
cd ../transducer_stateless2
|
||||
pytest -v -s
|
||||
|
||||
cd ../transducer_lstm
|
||||
@ -128,11 +143,26 @@ jobs:
|
||||
cd egs/librispeech/ASR/conformer_ctc
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless2
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless3
|
||||
pytest -v -s
|
||||
|
||||
cd ../pruned_transducer_stateless4
|
||||
pytest -v -s
|
||||
|
||||
cd ../transducer_stateless
|
||||
pytest -v -s
|
||||
|
||||
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
||||
cd ../transducer
|
||||
pytest -v -s
|
||||
|
||||
cd ../transducer_stateless
|
||||
cd ../transducer_stateless2
|
||||
pytest -v -s
|
||||
|
||||
cd ../transducer_lstm
|
||||
|
@ -116,8 +116,6 @@ def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
assert args.jit is False, "Support torchscript will be added later"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
@ -159,6 +157,11 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
|
@ -29,6 +29,7 @@ from decoder import Decoder
|
||||
def test_decoder():
|
||||
vocab_size = 3
|
||||
blank_id = 0
|
||||
unk_id = 2
|
||||
embedding_dim = 128
|
||||
context_size = 4
|
||||
|
||||
@ -36,6 +37,7 @@ def test_decoder():
|
||||
vocab_size=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
blank_id=blank_id,
|
||||
unk_id=unk_id,
|
||||
context_size=context_size,
|
||||
)
|
||||
N = 100
|
||||
|
50
egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
Executable file
50
egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
Executable file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.unk_id = 2
|
||||
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
torch.jit.script(model)
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -112,10 +112,13 @@ class Conformer(EncoderInterface):
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
|
||||
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
|
||||
#
|
||||
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||
|
||||
assert x.size(0) == lengths.max().item()
|
||||
mask = make_pad_mask(lengths)
|
||||
|
||||
|
@ -131,8 +131,6 @@ def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
assert args.jit is False, "Support torchscript will be added later"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
@ -191,6 +189,11 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
|
@ -212,7 +212,10 @@ class ScaledLinear(nn.Linear):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
||||
if self.bias is None or self.bias_scale is None:
|
||||
return None
|
||||
|
||||
return self.bias * self.bias_scale.exp()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return torch.nn.functional.linear(
|
||||
@ -255,7 +258,11 @@ class ScaledConv1d(nn.Conv1d):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
||||
bias = self.bias
|
||||
bias_scale = self.bias_scale
|
||||
if bias is None or bias_scale is None:
|
||||
return None
|
||||
return bias * bias_scale.exp()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
F = torch.nn.functional
|
||||
@ -269,7 +276,7 @@ class ScaledConv1d(nn.Conv1d):
|
||||
self.get_weight(),
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
_single(0),
|
||||
(0,),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
@ -319,7 +326,12 @@ class ScaledConv2d(nn.Conv2d):
|
||||
return self.weight * self.weight_scale.exp()
|
||||
|
||||
def get_bias(self):
|
||||
return None if self.bias is None else self.bias * self.bias_scale.exp()
|
||||
# see https://github.com/pytorch/pytorch/issues/24135
|
||||
bias = self.bias
|
||||
bias_scale = self.bias_scale
|
||||
if bias is None or bias_scale is None:
|
||||
return None
|
||||
return bias * bias_scale.exp()
|
||||
|
||||
def _conv_forward(self, input, weight):
|
||||
F = torch.nn.functional
|
||||
@ -333,7 +345,7 @@ class ScaledConv2d(nn.Conv2d):
|
||||
weight,
|
||||
self.get_bias(),
|
||||
self.stride,
|
||||
_pair(0),
|
||||
(0, 0),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
@ -398,6 +410,9 @@ class ActivationBalancer(torch.nn.Module):
|
||||
self.max_abs = max_abs
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting():
|
||||
return x
|
||||
|
||||
return ActivationBalancerFunction.apply(
|
||||
x,
|
||||
self.channel_dim,
|
||||
@ -444,6 +459,8 @@ class DoubleSwish(torch.nn.Module):
|
||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||
that we approximate closely with x * sigmoid(x-1).
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
|
||||
|
50
egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py
Executable file
50
egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py
Executable file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless2/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.unk_id = 2
|
||||
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
torch.jit.script(model)
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -132,8 +132,6 @@ def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
assert args.jit is False, "Support torchscript will be added later"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
@ -192,6 +190,11 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
|
50
egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py
Executable file
50
egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py
Executable file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless3/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.unk_id = 2
|
||||
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
torch.jit.script(model)
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
69
egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
Executable file
69
egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py
Executable file
@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless3/test_scaling.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from scaling import ActivationBalancer, ScaledConv1d, ScaledConv2d
|
||||
|
||||
|
||||
def test_scaled_conv1d():
|
||||
for bias in [True, False]:
|
||||
conv1d = ScaledConv1d(
|
||||
3,
|
||||
6,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
torch.jit.script(conv1d)
|
||||
|
||||
|
||||
def test_scaled_conv2d():
|
||||
for bias in [True, False]:
|
||||
conv2d = ScaledConv2d(
|
||||
in_channels=1,
|
||||
out_channels=3,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
torch.jit.script(conv2d)
|
||||
|
||||
|
||||
def test_activation_balancer():
|
||||
act = ActivationBalancer(
|
||||
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
torch.jit.script(act)
|
||||
|
||||
|
||||
def main():
|
||||
test_scaled_conv1d()
|
||||
test_scaled_conv2d()
|
||||
test_activation_balancer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
50
egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py
Executable file
50
egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py
Executable file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless4/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.unk_id = 2
|
||||
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
torch.jit.script(model)
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -94,7 +94,7 @@ class LstmEncoder(EncoderInterface):
|
||||
)
|
||||
|
||||
if False:
|
||||
# It is commented out as DPP complains that not all parameters are
|
||||
# It is commented out as DDP complains that not all parameters are
|
||||
# used. Need more checks later for the reason.
|
||||
#
|
||||
# Caution: We assume the dataloader returns utterances with
|
||||
@ -107,7 +107,7 @@ class LstmEncoder(EncoderInterface):
|
||||
)
|
||||
|
||||
packed_rnn_out, _ = self.rnn(packed_x)
|
||||
rnn_out, _ = pad_packed_sequence(packed_x, batch_first=True)
|
||||
rnn_out, _ = pad_packed_sequence(packed_rnn_out, batch_first=True)
|
||||
else:
|
||||
rnn_out, _ = self.rnn(x)
|
||||
|
||||
|
@ -97,8 +97,7 @@ class Transducer(nn.Module):
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_id = self.decoder.sos_id
|
||||
sos_y = add_sos(y, sos_id=sos_id)
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||
|
@ -109,10 +109,12 @@ class Conformer(Transformer):
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
|
||||
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
|
||||
#
|
||||
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||
|
||||
assert x.size(0) == lengths.max().item()
|
||||
mask = make_pad_mask(lengths)
|
||||
|
@ -183,8 +183,6 @@ def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
assert args.jit is False, "Support torchscript will be added later"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
@ -226,6 +224,11 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
|
@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -55,8 +57,8 @@ class Joiner(nn.Module):
|
||||
|
||||
N = encoder_out.size(0)
|
||||
|
||||
encoder_out_len = encoder_out_len.tolist()
|
||||
decoder_out_len = decoder_out_len.tolist()
|
||||
encoder_out_len: List[int] = encoder_out_len.tolist()
|
||||
decoder_out_len: List[int] = decoder_out_len.tolist()
|
||||
|
||||
encoder_out_list = [
|
||||
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
|
||||
|
49
egs/librispeech/ASR/transducer_stateless/test_model.py
Executable file
49
egs/librispeech/ASR/transducer_stateless/test_model.py
Executable file
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./transducer_stateless/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
torch.jit.script(model)
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -115,8 +115,6 @@ def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
assert args.jit is False, "Support torchscript will be added later"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
@ -158,6 +156,11 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
|
@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -30,7 +32,8 @@ class Joiner(nn.Module):
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
*unused,
|
||||
unused_encoder_out_len: Optional[torch.Tensor] = None,
|
||||
unused_decoder_out_len: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -38,10 +41,12 @@ class Joiner(nn.Module):
|
||||
Output from the encoder. Its shape is (N, T, self.input_dim).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, U, self.input_dim).
|
||||
unused:
|
||||
unused_encoder_out_len:
|
||||
This is a placeholder so that we can reuse
|
||||
transducer_stateless/beam_search.py in this folder as that
|
||||
script assumes the joiner networks accepts 4 inputs.
|
||||
unused_decoder_out_len:
|
||||
Just a placeholder.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, U, self.output_dim).
|
||||
"""
|
||||
|
49
egs/librispeech/ASR/transducer_stateless2/test_model.py
Executable file
49
egs/librispeech/ASR/transducer_stateless2/test_model.py
Executable file
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./transducer_stateless2/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
torch.jit.script(model)
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -184,8 +184,6 @@ def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
assert args.jit is False, "Support torchscript will be added later"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
@ -229,6 +227,11 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
|
Loading…
x
Reference in New Issue
Block a user