sort imports

This commit is contained in:
shaynemei 2022-08-01 20:49:52 -07:00
parent 9149a92dd8
commit 60621b242a
10 changed files with 33 additions and 25 deletions

View File

@ -66,14 +66,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from .asr_datamodule import GigaSpeechAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import (
@ -83,6 +75,14 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import GigaSpeechAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .gigaspeech_scoring import asr_text_post_processing
from .train import get_params, get_transducer_model

View File

@ -49,11 +49,12 @@ from pathlib import Path
import sentencepiece as spm
import torch
from .train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -56,14 +56,8 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import GigaSpeechAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from .model import Transducer
from .optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -77,6 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import GigaSpeechAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]

View File

@ -36,10 +36,10 @@ import argparse
from pathlib import Path
from typing import Dict, List
from .generate_unique_lexicon import filter_multiple_pronunications
from icefall.lexicon import read_lexicon
from .generate_unique_lexicon import filter_multiple_pronunications
def get_args():
parser = argparse.ArgumentParser()

View File

@ -41,6 +41,9 @@ from typing import Dict, List, Tuple
import k2
import sentencepiece as spm
import torch
from icefall.utils import str2bool
from .prepare_lang import (
Lexicon,
add_disambig_symbols,
@ -49,8 +52,6 @@ from .prepare_lang import (
write_mapping,
)
from icefall.utils import str2bool
def lexicon_to_fst_no_sil(
lexicon: Lexicon,

View File

@ -21,11 +21,12 @@ from typing import Dict, List, Optional
import k2
import sentencepiece as spm
import torch
from .model import Transducer
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import add_eos, add_sos, get_texts
from .model import Transducer
def fast_beam_search_one_best(
model: Transducer,

View File

@ -21,6 +21,10 @@ import warnings
from typing import List, Optional, Tuple
import torch
from torch import Tensor, nn
from icefall.utils import make_pad_mask, subsequent_chunk_mask
from .encoder_interface import EncoderInterface
from .scaling import (
ActivationBalancer,
@ -30,9 +34,6 @@ from .scaling import (
ScaledConv2d,
ScaledLinear,
)
from torch import Tensor, nn
from icefall.utils import make_pad_mask, subsequent_chunk_mask
class Conformer(EncoderInterface):

View File

@ -17,6 +17,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .scaling import ScaledConv1d, ScaledEmbedding

View File

@ -16,6 +16,7 @@
import torch
import torch.nn as nn
from .scaling import ScaledLinear

View File

@ -18,11 +18,12 @@
import k2
import torch
import torch.nn as nn
from .encoder_interface import EncoderInterface
from .scaling import ScaledLinear
from icefall.utils import add_sos
from .encoder_interface import EncoderInterface
from .scaling import ScaledLinear
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf