mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
fix code style
This commit is contained in:
parent
8fb81e27cb
commit
dae3a310f4
@ -55,9 +55,9 @@ Usage:
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
import re
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -72,12 +72,13 @@ from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from lhotse.cut import Cut
|
||||
from lstm import LOG_EPSILON, stack_states, unstack_states
|
||||
from local.text_normalize import text_normalize
|
||||
from lstm import LOG_EPSILON, stack_states, unstack_states
|
||||
from stream import Stream
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -85,7 +86,6 @@ from icefall.checkpoint import (
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -622,8 +622,8 @@ def create_streaming_feature_extractor() -> Fbank:
|
||||
opts.mel_opts.num_bins = 80
|
||||
return Fbank(opts)
|
||||
|
||||
|
||||
def filter_zh_en(text: str):
|
||||
import re
|
||||
pattern = re.compile(r"([\u4e00-\u9fff])")
|
||||
|
||||
chars = pattern.split(text.upper())
|
||||
@ -634,6 +634,7 @@ def filter_zh_en(text: str):
|
||||
chars_new.extend(tokens)
|
||||
return chars_new
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
cuts: CutSet,
|
||||
model: nn.Module,
|
||||
|
@ -62,9 +62,9 @@ from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from lstm import RNN
|
||||
from local.text_normalize import text_normalize
|
||||
from local.tokenize_with_bpe_model import tokenize_by_bpe_model
|
||||
from lstm import RNN
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
@ -161,7 +161,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
"--is-pnnx",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Only used when exporting model with pnnx."
|
||||
help="Only used when exporting model with pnnx.",
|
||||
)
|
||||
|
||||
|
||||
@ -643,7 +643,7 @@ def compute_loss(
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
#import pdb; pdb.set_trace()
|
||||
# import pdb; pdb.set_trace()
|
||||
y = graph_compiler.texts_to_ids_with_bpe(texts)
|
||||
if type(y) == list:
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user