From d64cdf6d4f309be0c81c2a47eb1580c630fa165a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 26 Jan 2024 17:35:32 +0800 Subject: [PATCH] Fix export for stateless5 --- .../pruned_transducer_stateless5/export.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py index cb541070e..5ff1f4a3b 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py @@ -20,7 +20,7 @@ Usage for offline: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 4 \ --avg 1 @@ -28,7 +28,7 @@ It will generate a file exp_dir/pretrained.pt for offline ASR. ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 4 \ --avg 1 \ --jit True @@ -38,7 +38,7 @@ It will generate a file exp_dir/cpu_jit.pt for offline ASR. Usage for streaming: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 7 \ --avg 1 @@ -46,7 +46,7 @@ It will generate a file exp_dir/pretrained.pt for streaming ASR. 
./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 7 \ --avg 1 \ --jit True @@ -73,13 +73,13 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -114,10 +114,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -152,10 +152,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params)