From 90bc61e9709de36ef063ab97639104b86c7c9d15 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 27 Jun 2022 19:46:55 +0800 Subject: [PATCH] Fix tests --- .../pruned_transducer_stateless/test_model.py | 26 ++++++++++ .../test_model.py | 51 +------------------ .../test_model.py | 51 +------------------ .../test_model.py | 51 +------------------ 4 files changed, 29 insertions(+), 150 deletions(-) mode change 100755 => 120000 egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py mode change 100755 => 120000 egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py mode change 100755 => 120000 egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py index 5c49025bd..1858d6bf0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py @@ -34,6 +34,31 @@ def test_model(): params.context_size = 2 params.unk_id = 2 + params.dynamic_chunk_training = False + params.short_chunk_size = 25 + params.num_left_chunks = 4 + params.causal_convolution = False + + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + torch.jit.script(model) + + +def test_model_streaming(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.unk_id = 2 + + params.dynamic_chunk_training = True + params.short_chunk_size = 25 + params.num_left_chunks = 4 + params.causal_convolution = True + model = get_transducer_model(params) num_param = sum([p.numel() for p in model.parameters()]) @@ -44,6 +69,7 @@ def test_model(): def main(): test_model() + test_model_streaming() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py deleted file mode 100755 index 9d5c6376d..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -To run this file, do: - - cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless2/test_model.py -""" - -import torch -from train import get_params, get_transducer_model - - -def test_model(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.unk_id = 2 - - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - torch.jit.script(model) - - -def main(): - test_model() - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py new file mode 120000 index 000000000..4196e587c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/test_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py deleted file mode 100755 index 9a060c5fb..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -To run this file, do: - - cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless3/test_model.py -""" - -import torch -from train import get_params, get_transducer_model - - -def test_model(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.unk_id = 2 - - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - torch.jit.script(model) - - -def main(): - test_model() - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py new file mode 120000 index 000000000..4196e587c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/test_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py deleted file mode 100755 index b1832d0ec..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -To run this file, do: - - cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless4/test_model.py -""" - -import torch -from train import get_params, get_transducer_model - - -def test_model(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.unk_id = 2 - - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - torch.jit.script(model) - - -def main(): - test_model() - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py new file mode 120000 index 000000000..4196e587c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/test_model.py \ No newline at end of file