fix code style

marcoyang 2023-02-13 17:49:32 +08:00
parent 8fb81e27cb
commit dae3a310f4
2 changed files with 16 additions and 15 deletions

Changed file 1 of 2 (filename not shown)

@@ -55,9 +55,9 @@ Usage:
     --max-contexts 4 \
     --max-states 8
 """
-import re
 import argparse
 import logging
+import re
 import warnings
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -72,12 +72,13 @@ from beam_search import Hypothesis, HypothesisList, get_hyps_shape
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
 from lhotse.cut import Cut
-from lstm import LOG_EPSILON, stack_states, unstack_states
 from local.text_normalize import text_normalize
+from lstm import LOG_EPSILON, stack_states, unstack_states
 from stream import Stream
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -85,7 +86,6 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.decode import one_best_decoding
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -622,10 +622,10 @@ def create_streaming_feature_extractor() -> Fbank:
     opts.mel_opts.num_bins = 80
     return Fbank(opts)


 def filter_zh_en(text: str):
-    import re
     pattern = re.compile(r"([\u4e00-\u9fff])")
     chars = pattern.split(text.upper())
     chars_new = []
     for char in chars:
@@ -634,6 +634,7 @@ def filter_zh_en(text: str):
             chars_new.extend(tokens)
     return chars_new

+
 def decode_dataset(
     cuts: CutSet,
     model: nn.Module,
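The two hunks above hoist `filter_zh_en`'s local `import re` to the top of the file and add a blank line before `decode_dataset`. For reference, the function tokenizes mixed Chinese-English text: the capture group in the regex keeps each CJK ideograph as its own token, while English runs are split on spaces. A self-contained sketch; the loop body is an assumption, since the diff elides it:

```python
import re


def filter_zh_en(text: str):
    # The capture group makes re.split keep each matched CJK
    # character as a separate list element.
    pattern = re.compile(r"([\u4e00-\u9fff])")
    chars = pattern.split(text.upper())
    chars_new = []
    for char in chars:
        # Assumed body (elided in the diff): skip empty segments
        # and split English segments on spaces.
        if char != "":
            chars_new.extend(char.strip().split(" "))
    return chars_new


print(filter_zh_en("hello 你好"))  # ['HELLO', '你', '好']
```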
@@ -954,12 +955,12 @@ def main():
         text = text.strip("\n").strip("\t")
         c.supervisions[0].text = text_normalize(text)
         return c

     tal_csasr = TAL_CSASRAsrDataModule(args)
     dev_cuts = tal_csasr.valid_cuts()
     dev_cuts = dev_cuts.map(text_normalize_for_cut)
     test_cuts = tal_csasr.test_cuts()
     test_cuts = test_cuts.map(text_normalize_for_cut)
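Earlier in this file (the `@@ -622` hunk), `create_streaming_feature_extractor` builds a kaldifeat `Fbank` with 80 mel bins. A minimal sketch of that pattern; everything except `num_bins = 80` is assumed, since the diff shows only the tail of the function:

```python
import torch
from kaldifeat import Fbank, FbankOptions


def create_streaming_feature_extractor() -> Fbank:
    opts = FbankOptions()
    opts.device = torch.device("cpu")
    opts.frame_opts.samp_freq = 16000  # assumed sampling rate
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.mel_opts.num_bins = 80  # shown in the hunk above
    return Fbank(opts)
```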

Changed file 2 of 2 (filename not shown)

@@ -62,9 +62,9 @@ from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from lstm import RNN
 from local.text_normalize import text_normalize
 from local.tokenize_with_bpe_model import tokenize_by_bpe_model
+from lstm import RNN
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -108,7 +108,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default=512,
         help="Encoder output dimension.",
     )

     parser.add_argument(
         "--decoder-dim",
         type=int,
@@ -156,12 +156,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         `grad_norm_threshold * median`, where `median` is the median
         value of gradient norms of all elements in batch.""",
     )

     parser.add_argument(
         "--is-pnnx",
         type=str2bool,
         default=False,
-        help="Only used when exporting model with pnnx."
+        help="Only used when exporting model with pnnx.",
     )
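The `--is-pnnx` hunk only appends a trailing comma to the `help=` argument. `str2bool` is icefall's argparse helper for boolean flags; a minimal sketch of such a converter (the exact implementation in `icefall.utils` may differ):

```python
import argparse


def str2bool(v):
    # Let users write --is-pnnx=true, --is-pnnx 0, etc.; argparse's
    # builtin bool() would treat any non-empty string as True.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
```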
@@ -643,7 +643,7 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
     texts = batch["supervisions"]["text"]
-    #import pdb; pdb.set_trace()
+    # import pdb; pdb.set_trace()
     y = graph_compiler.texts_to_ids_with_bpe(texts)
     if type(y) == list:
         y = k2.RaggedTensor(y).to(device)
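The `type(y) == list` guard suggests `texts_to_ids_with_bpe` can return a plain list of per-utterance token-id lists, which the hunk then wraps in a `k2.RaggedTensor`. A small illustration with made-up token ids:

```python
import k2
import torch

# One sublist of token ids per utterance; lengths differ, hence a
# ragged layout instead of a padded rectangular tensor.
y = k2.RaggedTensor([[5, 9, 3], [7, 2]])
y = y.to(torch.device("cpu"))  # mirrors .to(device) in the hunk
```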
@@ -805,7 +805,7 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     cur_batch_idx = params.get("cur_batch_idx", 0)

     for batch_idx, batch in enumerate(train_dl):
         if batch_idx < cur_batch_idx:
             continue
@@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         return 1.0 <= c.duration <= 20.0

     def text_normalize_for_cut(c: Cut):
         # Text normalize for each sample
         text = c.supervisions[0].text
@@ -1040,7 +1040,7 @@ def run(rank, world_size, args):
         text = tokenize_by_bpe_model(sp, text)
         c.supervisions[0].text = text
         return c

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
     train_cuts = train_cuts.map(text_normalize_for_cut)
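These final lines chain `filter` and `map` on a lhotse `CutSet`; both are lazy, so the predicate and the transform run per cut as data is drawn. A sketch of the same pattern with a placeholder manifest path and a stand-in transform:

```python
from lhotse import CutSet

cuts = CutSet.from_file("cuts_train.jsonl.gz")  # placeholder path

# Keep utterances between 1 s and 20 s, as remove_short_and_long_utt does.
cuts = cuts.filter(lambda c: 1.0 <= c.duration <= 20.0)


def normalize(c):
    # Stand-in for text_normalize_for_cut: rewrite supervision text in place.
    c.supervisions[0].text = c.supervisions[0].text.strip()
    return c


cuts = cuts.map(normalize)
```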