fix code style

This commit is contained in:
marcoyang 2023-02-13 17:49:32 +08:00
parent 8fb81e27cb
commit dae3a310f4
2 changed files with 16 additions and 15 deletions

View File

@ -55,9 +55,9 @@ Usage:
--max-contexts 4 \ --max-contexts 4 \
--max-states 8 --max-states 8
""" """
import re
import argparse import argparse
import logging import logging
import re
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -72,12 +72,13 @@ from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from kaldifeat import Fbank, FbankOptions from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet from lhotse import CutSet
from lhotse.cut import Cut from lhotse.cut import Cut
from lstm import LOG_EPSILON, stack_states, unstack_states
from local.text_normalize import text_normalize from local.text_normalize import text_normalize
from lstm import LOG_EPSILON, stack_states, unstack_states
from stream import Stream from stream import Stream
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,
@ -85,7 +86,6 @@ from icefall.checkpoint import (
load_checkpoint, load_checkpoint,
) )
from icefall.decode import one_best_decoding from icefall.decode import one_best_decoding
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -622,10 +622,10 @@ def create_streaming_feature_extractor() -> Fbank:
opts.mel_opts.num_bins = 80 opts.mel_opts.num_bins = 80
return Fbank(opts) return Fbank(opts)
def filter_zh_en(text: str): def filter_zh_en(text: str):
import re
pattern = re.compile(r"([\u4e00-\u9fff])") pattern = re.compile(r"([\u4e00-\u9fff])")
chars = pattern.split(text.upper()) chars = pattern.split(text.upper())
chars_new = [] chars_new = []
for char in chars: for char in chars:
@ -634,6 +634,7 @@ def filter_zh_en(text: str):
chars_new.extend(tokens) chars_new.extend(tokens)
return chars_new return chars_new
def decode_dataset( def decode_dataset(
cuts: CutSet, cuts: CutSet,
model: nn.Module, model: nn.Module,
@ -954,12 +955,12 @@ def main():
text = text.strip("\n").strip("\t") text = text.strip("\n").strip("\t")
c.supervisions[0].text = text_normalize(text) c.supervisions[0].text = text_normalize(text)
return c return c
tal_csasr = TAL_CSASRAsrDataModule(args) tal_csasr = TAL_CSASRAsrDataModule(args)
dev_cuts = tal_csasr.valid_cuts() dev_cuts = tal_csasr.valid_cuts()
dev_cuts = dev_cuts.map(text_normalize_for_cut) dev_cuts = dev_cuts.map(text_normalize_for_cut)
test_cuts = tal_csasr.test_cuts() test_cuts = tal_csasr.test_cuts()
test_cuts = test_cuts.map(text_normalize_for_cut) test_cuts = test_cuts.map(text_normalize_for_cut)

View File

@ -62,9 +62,9 @@ from joiner import Joiner
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from lstm import RNN
from local.text_normalize import text_normalize from local.text_normalize import text_normalize
from local.tokenize_with_bpe_model import tokenize_by_bpe_model from local.tokenize_with_bpe_model import tokenize_by_bpe_model
from lstm import RNN
from model import Transducer from model import Transducer
from optim import Eden, Eve from optim import Eden, Eve
from torch import Tensor from torch import Tensor
@ -108,7 +108,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default=512, default=512,
help="Encoder output dimesion.", help="Encoder output dimesion.",
) )
parser.add_argument( parser.add_argument(
"--decoder-dim", "--decoder-dim",
type=int, type=int,
@ -156,12 +156,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
`grad_norm_threshold * median`, where `median` is the median `grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elememts in batch.""", value of gradient norms of all elememts in batch.""",
) )
parser.add_argument( parser.add_argument(
"--is-pnnx", "--is-pnnx",
type=str2bool, type=str2bool,
default=False, default=False,
help="Only used when exporting model with pnnx." help="Only used when exporting model with pnnx.",
) )
@ -643,7 +643,7 @@ def compute_loss(
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
#import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
y = graph_compiler.texts_to_ids_with_bpe(texts) y = graph_compiler.texts_to_ids_with_bpe(texts)
if type(y) == list: if type(y) == list:
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
@ -805,7 +805,7 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0) cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx: if batch_idx < cur_batch_idx:
continue continue
@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
# an utterance duration distribution for your dataset to select # an utterance duration distribution for your dataset to select
# the threshold # the threshold
return 1.0 <= c.duration <= 20.0 return 1.0 <= c.duration <= 20.0
def text_normalize_for_cut(c: Cut): def text_normalize_for_cut(c: Cut):
# Text normalize for each sample # Text normalize for each sample
text = c.supervisions[0].text text = c.supervisions[0].text
@ -1040,7 +1040,7 @@ def run(rank, world_size, args):
text = tokenize_by_bpe_model(sp, text) text = tokenize_by_bpe_model(sp, text)
c.supervisions[0].text = text c.supervisions[0].text = text
return c return c
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_cuts = train_cuts.map(text_normalize_for_cut) train_cuts = train_cuts.map(text_normalize_for_cut)