use explicit relative imports for aishell

This commit is contained in:
shaynemei 2022-08-01 21:34:58 -07:00
parent dd25072b3b
commit 51f2e377ba
40 changed files with 106 additions and 107 deletions

View File

@ -22,7 +22,7 @@ from typing import Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask from .transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer): class Conformer(Transformer):

View File

@ -26,8 +26,8 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from conformer import Conformer from .conformer import Conformer
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint

View File

@ -24,7 +24,7 @@ import logging
from pathlib import Path from pathlib import Path
import torch import torch
from conformer import Conformer from .conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon

View File

@ -26,7 +26,7 @@ import k2
import kaldifeat import kaldifeat
import torch import torch
import torchaudio import torchaudio
from conformer import Conformer from .conformer import Conformer
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from icefall.decode import ( from icefall.decode import (

View File

@ -16,8 +16,8 @@
# limitations under the License. # limitations under the License.
from subsampling import Conv2dSubsampling from .subsampling import Conv2dSubsampling
from subsampling import VggSubsampling from .subsampling import VggSubsampling
import torch import torch

View File

@ -18,7 +18,7 @@
import torch import torch
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from transformer import ( from .transformer import (
Transformer, Transformer,
add_eos, add_eos,
add_sos, add_sos,

View File

@ -26,13 +26,13 @@ import k2
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from conformer import Conformer from .conformer import Conformer
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformer import Noam from .transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint

View File

@ -20,8 +20,8 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from label_smoothing import LabelSmoothingLoss from .label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling from .subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed. # Note: TorchScript requires Dict/List/etc. to be fully typed.

View File

@ -22,7 +22,7 @@ from typing import Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask from .transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer): class Conformer(Transformer):

View File

@ -26,8 +26,8 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from conformer import Conformer from .conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import ( from icefall.decode import (

View File

@ -28,13 +28,13 @@ import k2
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from conformer import Conformer from .conformer import Conformer
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformer import Noam from .transformer import Noam
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl

View File

@ -20,8 +20,8 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from label_smoothing import LabelSmoothingLoss from .label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling from .subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed. # Note: TorchScript requires Dict/List/etc. to be fully typed.

View File

@ -39,7 +39,7 @@ from typing import Dict, List
import k2 import k2
import torch import torch
from prepare_lang import ( from .prepare_lang import (
Lexicon, Lexicon,
add_disambig_symbols, add_disambig_symbols,
add_self_loops, add_self_loops,

View File

@ -22,7 +22,7 @@ import os
import tempfile import tempfile
import k2 import k2
from prepare_lang import ( from .prepare_lang import (
add_disambig_symbols, add_disambig_symbols,
generate_id_map, generate_id_map,
get_phones, get_phones,

View File

@ -66,16 +66,16 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from aishell import AIShell from .aishell import AIShell
from asr_datamodule import AsrDataModule from .asr_datamodule import AsrDataModule
from beam_search import ( from .beam_search import (
beam_search, beam_search,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
) )
from train import add_model_arguments, get_params, get_transducer_model from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,

View File

@ -48,7 +48,7 @@ import logging
from pathlib import Path from pathlib import Path
import torch import torch
from train import add_model_arguments, get_params, get_transducer_model from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,

View File

@ -20,8 +20,8 @@ from typing import Optional
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from .encoder_interface import EncoderInterface
from scaling import ScaledLinear from .scaling import ScaledLinear
from icefall.utils import add_sos from icefall.utils import add_sos

View File

@ -65,7 +65,7 @@ import k2
import kaldifeat import kaldifeat
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from .beam_search import (
beam_search, beam_search,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
@ -73,7 +73,7 @@ from beam_search import (
modified_beam_search, modified_beam_search,
) )
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from .train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon

View File

@ -58,23 +58,22 @@ from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
import k2 import k2
import optim
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from aidatatang_200zh import AIDatatang200zh from .aidatatang_200zh import AIDatatang200zh
from aishell import AIShell from .aishell import AIShell
from asr_datamodule import AsrDataModule from .asr_datamodule import AsrDataModule
from conformer import Conformer from .conformer import Conformer
from decoder import Decoder from .decoder import Decoder
from joiner import Joiner from .joiner import Joiner
from lhotse import CutSet, load_manifest from lhotse import CutSet, load_manifest
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from .model import Transducer
from optim import Eden, Eve from .optim import Eden, Eve, LRScheduler
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@ -94,7 +93,7 @@ from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[ LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler torch.optim.lr_scheduler._LRScheduler, LRScheduler
] ]

View File

@ -24,8 +24,8 @@ from typing import Dict, List, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from model import TdnnLstm from .model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice, nbest_decoding, one_best_decoding from icefall.decode import get_lattice, nbest_decoding, one_best_decoding

View File

@ -25,7 +25,7 @@ import k2
import kaldifeat import kaldifeat
import torch import torch
import torchaudio import torchaudio
from model import TdnnLstm from .model import TdnnLstm
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding from icefall.decode import get_lattice, one_best_decoding

View File

@ -36,9 +36,9 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import TdnnLstm from .model import TdnnLstm
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR

View File

@ -19,7 +19,7 @@ from typing import Dict, List, Optional
import numpy as np import numpy as np
import torch import torch
from model import Transducer from .model import Transducer
def greedy_search( def greedy_search(

View File

@ -22,7 +22,7 @@ from typing import Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from transformer import Transformer from .transformer import Transformer
from icefall.utils import make_pad_mask from icefall.utils import make_pad_mask

View File

@ -24,12 +24,12 @@ from typing import Dict, List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from beam_search import beam_search, greedy_search from .beam_search import beam_search, greedy_search
from conformer import Conformer from .conformer import Conformer
from decoder import Decoder from .decoder import Decoder
from joiner import Joiner from .joiner import Joiner
from model import Transducer from .model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info

View File

@ -49,10 +49,10 @@ from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
from conformer import Conformer from .conformer import Conformer
from decoder import Decoder from .decoder import Decoder
from joiner import Joiner from .joiner import Joiner
from model import Transducer from .model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info

View File

@ -17,7 +17,7 @@
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from .encoder_interface import EncoderInterface
from icefall.utils import add_sos from icefall.utils import add_sos

View File

@ -51,11 +51,11 @@ import kaldifeat
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchaudio import torchaudio
from beam_search import beam_search, greedy_search from .beam_search import beam_search, greedy_search
from conformer import Conformer from .conformer import Conformer
from decoder import Decoder from .decoder import Decoder
from joiner import Joiner from .joiner import Joiner
from model import Transducer from .model import Transducer
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler

View File

@ -23,7 +23,7 @@ To run this file, do:
""" """
import torch import torch
from decoder import Decoder from .decoder import Decoder
def test_decoder(): def test_decoder():

View File

@ -30,18 +30,18 @@ import k2
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from conformer import Conformer from .conformer import Conformer
from decoder import Decoder from .decoder import Decoder
from joiner import Joiner from .joiner import Joiner
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from .model import Transducer
from torch import Tensor from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformer import Noam from .transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint

View File

@ -20,8 +20,8 @@ from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from .encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling from .subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import make_pad_mask from icefall.utils import make_pad_mask

View File

@ -63,16 +63,16 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from aishell import AIShell from .aishell import AIShell
from asr_datamodule import AsrDataModule from .asr_datamodule import AsrDataModule
from beam_search import ( from .beam_search import (
beam_search, beam_search,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
) )
from train import get_params, get_transducer_model from .train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon

View File

@ -48,10 +48,10 @@ from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
from conformer import Conformer from .conformer import Conformer
from decoder import Decoder from .decoder import Decoder
from joiner import Joiner from .joiner import Joiner
from model import Transducer from .model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info

View File

@ -20,7 +20,7 @@ from typing import Optional
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from .encoder_interface import EncoderInterface
from icefall.utils import add_sos from icefall.utils import add_sos

View File

@ -65,7 +65,7 @@ import k2
import kaldifeat import kaldifeat
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from .beam_search import (
beam_search, beam_search,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
@ -73,7 +73,7 @@ from beam_search import (
modified_beam_search, modified_beam_search,
) )
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model from .train import get_params, get_transducer_model
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon

View File

@ -50,21 +50,21 @@ import k2
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from aidatatang_200zh import AIDatatang200zh from .aidatatang_200zh import AIDatatang200zh
from aishell import AIShell from .aishell import AIShell
from asr_datamodule import AsrDataModule from .asr_datamodule import AsrDataModule
from conformer import Conformer from .conformer import Conformer
from decoder import Decoder from .decoder import Decoder
from joiner import Joiner from .joiner import Joiner
from lhotse import CutSet, load_manifest from lhotse import CutSet, load_manifest
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from .model import Transducer
from torch import Tensor from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformer import Noam from .transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint

View File

@ -65,15 +65,15 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from beam_search import ( from .beam_search import (
beam_search, beam_search,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
) )
from train import get_params, get_transducer_model from .train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon

View File

@ -48,10 +48,10 @@ from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
from conformer import Conformer from .conformer import Conformer
from decoder import Decoder from .decoder import Decoder
from joiner import Joiner from .joiner import Joiner
from model import Transducer from .model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info

View File

@ -23,7 +23,7 @@ To run this file, do:
""" """
import torch import torch
from decoder import Decoder from .decoder import Decoder
def test_decoder(): def test_decoder():

View File

@ -46,18 +46,18 @@ import k2
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AishellAsrDataModule from .asr_datamodule import AishellAsrDataModule
from conformer import Conformer from .conformer import Conformer
from decoder import Decoder from .decoder import Decoder
from joiner import Joiner from .joiner import Joiner
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from .model import Transducer
from torch import Tensor from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformer import Noam from .transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint