mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
fix code style
This commit is contained in:
parent
8fb81e27cb
commit
dae3a310f4
@ -55,9 +55,9 @@ Usage:
|
|||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
"""
|
"""
|
||||||
import re
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@ -72,12 +72,13 @@ from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
|||||||
from kaldifeat import Fbank, FbankOptions
|
from kaldifeat import Fbank, FbankOptions
|
||||||
from lhotse import CutSet
|
from lhotse import CutSet
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lstm import LOG_EPSILON, stack_states, unstack_states
|
|
||||||
from local.text_normalize import text_normalize
|
from local.text_normalize import text_normalize
|
||||||
|
from lstm import LOG_EPSILON, stack_states, unstack_states
|
||||||
from stream import Stream
|
from stream import Stream
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -85,7 +86,6 @@ from icefall.checkpoint import (
|
|||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.decode import one_best_decoding
|
from icefall.decode import one_best_decoding
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -622,10 +622,10 @@ def create_streaming_feature_extractor() -> Fbank:
|
|||||||
opts.mel_opts.num_bins = 80
|
opts.mel_opts.num_bins = 80
|
||||||
return Fbank(opts)
|
return Fbank(opts)
|
||||||
|
|
||||||
|
|
||||||
def filter_zh_en(text: str):
|
def filter_zh_en(text: str):
|
||||||
import re
|
|
||||||
pattern = re.compile(r"([\u4e00-\u9fff])")
|
pattern = re.compile(r"([\u4e00-\u9fff])")
|
||||||
|
|
||||||
chars = pattern.split(text.upper())
|
chars = pattern.split(text.upper())
|
||||||
chars_new = []
|
chars_new = []
|
||||||
for char in chars:
|
for char in chars:
|
||||||
@ -634,6 +634,7 @@ def filter_zh_en(text: str):
|
|||||||
chars_new.extend(tokens)
|
chars_new.extend(tokens)
|
||||||
return chars_new
|
return chars_new
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
cuts: CutSet,
|
cuts: CutSet,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -954,12 +955,12 @@ def main():
|
|||||||
text = text.strip("\n").strip("\t")
|
text = text.strip("\n").strip("\t")
|
||||||
c.supervisions[0].text = text_normalize(text)
|
c.supervisions[0].text = text_normalize(text)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
tal_csasr = TAL_CSASRAsrDataModule(args)
|
tal_csasr = TAL_CSASRAsrDataModule(args)
|
||||||
|
|
||||||
dev_cuts = tal_csasr.valid_cuts()
|
dev_cuts = tal_csasr.valid_cuts()
|
||||||
dev_cuts = dev_cuts.map(text_normalize_for_cut)
|
dev_cuts = dev_cuts.map(text_normalize_for_cut)
|
||||||
|
|
||||||
test_cuts = tal_csasr.test_cuts()
|
test_cuts = tal_csasr.test_cuts()
|
||||||
test_cuts = test_cuts.map(text_normalize_for_cut)
|
test_cuts = test_cuts.map(text_normalize_for_cut)
|
||||||
|
|
||||||
|
@ -62,9 +62,9 @@ from joiner import Joiner
|
|||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from lstm import RNN
|
|
||||||
from local.text_normalize import text_normalize
|
from local.text_normalize import text_normalize
|
||||||
from local.tokenize_with_bpe_model import tokenize_by_bpe_model
|
from local.tokenize_with_bpe_model import tokenize_by_bpe_model
|
||||||
|
from lstm import RNN
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
from optim import Eden, Eve
|
from optim import Eden, Eve
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@ -108,7 +108,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
default=512,
|
default=512,
|
||||||
help="Encoder output dimesion.",
|
help="Encoder output dimesion.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder-dim",
|
"--decoder-dim",
|
||||||
type=int,
|
type=int,
|
||||||
@ -156,12 +156,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
`grad_norm_threshold * median`, where `median` is the median
|
`grad_norm_threshold * median`, where `median` is the median
|
||||||
value of gradient norms of all elememts in batch.""",
|
value of gradient norms of all elememts in batch.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--is-pnnx",
|
"--is-pnnx",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="Only used when exporting model with pnnx."
|
help="Only used when exporting model with pnnx.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -643,7 +643,7 @@ def compute_loss(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
#import pdb; pdb.set_trace()
|
# import pdb; pdb.set_trace()
|
||||||
y = graph_compiler.texts_to_ids_with_bpe(texts)
|
y = graph_compiler.texts_to_ids_with_bpe(texts)
|
||||||
if type(y) == list:
|
if type(y) == list:
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
@ -805,7 +805,7 @@ def train_one_epoch(
|
|||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
if batch_idx < cur_batch_idx:
|
if batch_idx < cur_batch_idx:
|
||||||
continue
|
continue
|
||||||
@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
|
|||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
return 1.0 <= c.duration <= 20.0
|
return 1.0 <= c.duration <= 20.0
|
||||||
|
|
||||||
def text_normalize_for_cut(c: Cut):
|
def text_normalize_for_cut(c: Cut):
|
||||||
# Text normalize for each sample
|
# Text normalize for each sample
|
||||||
text = c.supervisions[0].text
|
text = c.supervisions[0].text
|
||||||
@ -1040,7 +1040,7 @@ def run(rank, world_size, args):
|
|||||||
text = tokenize_by_bpe_model(sp, text)
|
text = tokenize_by_bpe_model(sp, text)
|
||||||
c.supervisions[0].text = text
|
c.supervisions[0].text = text
|
||||||
return c
|
return c
|
||||||
|
|
||||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
train_cuts = train_cuts.map(text_normalize_for_cut)
|
train_cuts = train_cuts.map(text_normalize_for_cut)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user