use explicit relative imports for aishell4

This commit is contained in:
shaynemei 2022-08-01 21:37:24 -07:00
parent 4cb62a3202
commit 6df69603a8
7 changed files with 18 additions and 19 deletions

View File

@@ -39,7 +39,7 @@ from typing import Dict, List
 import k2
 import torch
-from prepare_lang import (
+from .prepare_lang import (
     Lexicon,
     add_disambig_symbols,
     add_self_loops,

View File

@@ -22,7 +22,7 @@ import os
 import tempfile
 import k2
-from prepare_lang import (
+from .prepare_lang import (
     add_disambig_symbols,
     generate_id_map,
     get_phones,

View File

@@ -61,8 +61,8 @@ from typing import Dict, List, Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
-from asr_datamodule import Aishell4AsrDataModule
-from beam_search import (
+from .asr_datamodule import Aishell4AsrDataModule
+from .beam_search import (
     beam_search,
     fast_beam_search_one_best,
     greedy_search,
@@ -70,8 +70,8 @@ from beam_search import (
     modified_beam_search,
 )
 from lhotse.cut import Cut
-from local.text_normalize import text_normalize
-from train import add_model_arguments, get_params, get_transducer_model
+from ..local.text_normalize import text_normalize
+from .train import add_model_arguments, get_params, get_transducer_model
 from icefall.checkpoint import (
     average_checkpoints,

View File

@@ -49,7 +49,7 @@ import logging
 from pathlib import Path
 import torch
-from train import add_model_arguments, get_params, get_transducer_model
+from .train import add_model_arguments, get_params, get_transducer_model
 from icefall.checkpoint import (
     average_checkpoints,

View File

@@ -72,7 +72,7 @@ import k2
 import kaldifeat
 import torch
 import torchaudio
-from beam_search import (
+from .beam_search import (
     beam_search,
     fast_beam_search_one_best,
     greedy_search,
@@ -80,7 +80,7 @@ from beam_search import (
     modified_beam_search,
 )
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
+from .train import add_model_arguments, get_params, get_transducer_model
 from icefall.lexicon import Lexicon

View File

@@ -23,7 +23,7 @@ To run this file, do:
 python ./pruned_transducer_stateless5/test_model.py
 """
-from train import get_params, get_transducer_model
+from .train import get_params, get_transducer_model
 def test_model_1():

View File

@@ -53,20 +53,19 @@ from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
 import k2
-import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import Aishell4AsrDataModule
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
+from .asr_datamodule import Aishell4AsrDataModule
+from .conformer import Conformer
+from .decoder import Decoder
+from .joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from local.text_normalize import text_normalize
-from model import Transducer
-from optim import Eden, Eve
+from ..local.text_normalize import text_normalize
+from .model import Transducer
+from .optim import Eden, Eve, LRScheduler
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -86,7 +85,7 @@ from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+    torch.optim.lr_scheduler._LRScheduler, LRScheduler
 ]