diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py index 2bc5bfe8c..f49de2983 100644 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -55,9 +55,9 @@ Usage: --max-contexts 4 \ --max-states 8 """ -import re import argparse import logging +import re import warnings from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -72,12 +72,13 @@ from beam_search import Hypothesis, HypothesisList, get_hyps_shape from kaldifeat import Fbank, FbankOptions from lhotse import CutSet from lhotse.cut import Cut -from lstm import LOG_EPSILON, stack_states, unstack_states from local.text_normalize import text_normalize +from lstm import LOG_EPSILON, stack_states, unstack_states from stream import Stream from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -85,7 +86,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.decode import one_best_decoding -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -622,10 +622,10 @@ def create_streaming_feature_extractor() -> Fbank: opts.mel_opts.num_bins = 80 return Fbank(opts) + def filter_zh_en(text: str): - import re pattern = re.compile(r"([\u4e00-\u9fff])") - + chars = pattern.split(text.upper()) chars_new = [] for char in chars: @@ -634,6 +634,7 @@ def filter_zh_en(text: str): chars_new.extend(tokens) return chars_new + def decode_dataset( cuts: CutSet, model: nn.Module, @@ -954,12 +955,12 @@ def main(): text = text.strip("\n").strip("\t") c.supervisions[0].text = text_normalize(text) return c - + tal_csasr = TAL_CSASRAsrDataModule(args) dev_cuts = tal_csasr.valid_cuts() dev_cuts = dev_cuts.map(text_normalize_for_cut) - + test_cuts = tal_csasr.test_cuts() test_cuts = test_cuts.map(text_normalize_for_cut) diff --git a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py index c6cdd7823..c67aad202 100755 --- a/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py +++ b/egs/tal_csasr/ASR/lstm_transducer_stateless3/train.py @@ -62,9 +62,9 @@ from joiner import Joiner from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from lstm import RNN from local.text_normalize import text_normalize from local.tokenize_with_bpe_model import tokenize_by_bpe_model +from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor @@ -108,7 +108,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): default=512, help="Encoder output dimesion.", ) - + parser.add_argument( "--decoder-dim", type=int, @@ -156,12 +156,12 @@ def add_model_arguments(parser: argparse.ArgumentParser): `grad_norm_threshold * median`, where `median` is the median value of gradient norms of all elememts in batch.""", ) - + parser.add_argument( "--is-pnnx", type=str2bool, default=False, - help="Only used when exporting model with pnnx." + help="Only used when exporting model with pnnx.", ) @@ -643,7 +643,7 @@ def compute_loss( feature_lens = supervisions["num_frames"].to(device) texts = batch["supervisions"]["text"] - #import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() y = graph_compiler.texts_to_ids_with_bpe(texts) if type(y) == list: y = k2.RaggedTensor(y).to(device) @@ -805,7 +805,7 @@ def train_one_epoch( tot_loss = MetricsTracker() cur_batch_idx = params.get("cur_batch_idx", 0) - + for batch_idx, batch in enumerate(train_dl): if batch_idx < cur_batch_idx: continue @@ -1031,7 +1031,7 @@ def run(rank, world_size, args): # an utterance duration distribution for your dataset to select # the threshold return 1.0 <= c.duration <= 20.0 - + def text_normalize_for_cut(c: Cut): # Text normalize for each sample text = c.supervisions[0].text @@ -1040,7 +1040,7 @@ def run(rank, world_size, args): text = tokenize_by_bpe_model(sp, text) c.supervisions[0].text = text return c - + train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.map(text_normalize_for_cut)