From 7cf3ea8b33d645a88fea7ffbb44c9a09721d39b8 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 26 Jan 2024 12:17:48 +0800
Subject: [PATCH] fix export for aidatatang_200zh

---
 .../pruned_transducer_stateless2/export.py    | 25 ++++++++++++-------
 .../ASR/pruned_transducer_stateless2/lstmp.py |  1 +
 .../scaling_converter.py                      |  1 +
 3 files changed, 18 insertions(+), 9 deletions(-)
 mode change 100644 => 100755 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
 create mode 120000 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
 create mode 120000 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py

diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
old mode 100644
new mode 100755
index e348f7b2b..5179bfa1c
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
@@ -20,7 +21,7 @@
 Usage:
 ./pruned_transducer_stateless2/export.py \
   --exp-dir ./pruned_transducer_stateless2/exp \
-  --lang-dir data/lang_char \
+  --tokens data/lang_char/tokens.txt \
   --epoch 29 \
   --avg 19
 
@@ -45,12 +46,13 @@ import argparse
 import logging
 from pathlib import Path
 
+import k2
 import torch
+from scaling_converter import convert_scaled_to_non_scaled
 from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
 
 
 def get_parser():
@@ -85,10 +87,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_char",
-        help="The lang dir",
+        default="data/lang_char/tokens.txt",
+        help="Path to the tokens.txt.",
     )
 
     parser.add_argument(
@@ -122,10 +124,14 @@ def main():
 
     logging.info(f"device: {device}")
 
-    lexicon = Lexicon(params.lang_dir)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
 
-    params.blank_id = 0
-    params.vocab_size = max(lexicon.tokens) + 1
+    # Load id of the <blk> token and the vocab size
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>
 
     logging.info(params)
 
@@ -152,6 +158,7 @@ def main():
     model.eval()
 
     if params.jit:
+        convert_scaled_to_non_scaled(model, inplace=True)
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
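
For reference, a minimal standalone sketch (not part of the patch) of the tokens.txt handling that export.py now uses instead of Lexicon; it assumes k2 and icefall are installed and that data/lang_char/tokens.txt defines <blk> and <unk>:

#!/usr/bin/env python3
# Sketch only: mirrors the token_table logic added to export.py above.
import k2

from icefall.utils import num_tokens

# tokens.txt maps each symbol to an integer id, one "symbol id" pair per line.
token_table = k2.SymbolTable.from_file("data/lang_char/tokens.txt")

blank_id = token_table["<blk>"]  # id of the blank symbol
unk_id = token_table["<unk>"]  # id of the unknown symbol
vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

print(f"blank_id={blank_id}, unk_id={unk_id}, vocab_size={vocab_size}")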