From 4cb62a3202a22c4212a9bbb8cd82119f2449799d Mon Sep 17 00:00:00 2001 From: shaynemei Date: Mon, 1 Aug 2022 21:35:06 -0700 Subject: [PATCH] use explicit relative imports for aishell2 --- .../ASR/pruned_transducer_stateless5/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/export.py | 2 +- .../pruned_transducer_stateless5/pretrained.py | 4 ++-- .../ASR/pruned_transducer_stateless5/train.py | 15 +++++++-------- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index f03bd34d3..f3b970527 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -110,8 +110,8 @@ from typing import Dict, List, Optional, Tuple import k2 import torch import torch.nn as nn -from asr_datamodule import AiShell2AsrDataModule -from beam_search import ( +from .asr_datamodule import AiShell2AsrDataModule +from .beam_search import ( beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, @@ -121,7 +121,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import add_model_arguments, get_params, get_transducer_model +from .train import add_model_arguments, get_params, get_transducer_model from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index bc7bd71cb..b75a1c8b8 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -49,7 +49,7 @@ import logging from pathlib import Path import torch -from train import add_model_arguments, get_params, get_transducer_model +from .train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index 09de1bece..16549153f 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -59,7 +59,7 @@ import k2 import kaldifeat import torch import torchaudio -from beam_search import ( +from .beam_search import ( beam_search, fast_beam_search_one_best, greedy_search, @@ -67,7 +67,7 @@ from beam_search import ( modified_beam_search, ) from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model +from .train import add_model_arguments, get_params, get_transducer_model from icefall.lexicon import Lexicon diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 838a0497f..feae9ed4b 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -61,19 +61,18 @@ from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 -import optim import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import AiShell2AsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner +from .asr_datamodule import AiShell2AsrDataModule +from .conformer import Conformer +from .decoder import Decoder +from .joiner import Joiner from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve +from .model import Transducer +from .optim import Eden, Eve, LRScheduler from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP @@ -93,7 +92,7 @@ from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler + torch.optim.lr_scheduler._LRScheduler, LRScheduler ]