use explicit relative imports for aishell
This commit is contained in:
parent
dd25072b3b
commit
51f2e377ba
@ -22,7 +22,7 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||||
|
|
||||||
|
|
||||||
class Conformer(Transformer):
|
class Conformer(Transformer):
|
||||||
|
|||||||
@ -26,8 +26,8 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
|
|||||||
@ -24,7 +24,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
|||||||
@ -26,7 +26,7 @@ import k2
|
|||||||
import kaldifeat
|
import kaldifeat
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
|
|||||||
@ -16,8 +16,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from subsampling import Conv2dSubsampling
|
from .subsampling import Conv2dSubsampling
|
||||||
from subsampling import VggSubsampling
|
from .subsampling import VggSubsampling
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from transformer import (
|
from .transformer import (
|
||||||
Transformer,
|
Transformer,
|
||||||
add_eos,
|
add_eos,
|
||||||
add_sos,
|
add_sos,
|
||||||
|
|||||||
@ -26,13 +26,13 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from transformer import Noam
|
from .transformer import Noam
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
|||||||
@ -20,8 +20,8 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from label_smoothing import LabelSmoothingLoss
|
from .label_smoothing import LabelSmoothingLoss
|
||||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||||
|
|
||||||
|
|
||||||
class Conformer(Transformer):
|
class Conformer(Transformer):
|
||||||
|
|||||||
@ -26,8 +26,8 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
|
|||||||
@ -28,13 +28,13 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from transformer import Noam
|
from .transformer import Noam
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
|
|||||||
@ -20,8 +20,8 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from label_smoothing import LabelSmoothingLoss
|
from .label_smoothing import LabelSmoothingLoss
|
||||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||||
|
|||||||
@ -39,7 +39,7 @@ from typing import Dict, List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from prepare_lang import (
|
from .prepare_lang import (
|
||||||
Lexicon,
|
Lexicon,
|
||||||
add_disambig_symbols,
|
add_disambig_symbols,
|
||||||
add_self_loops,
|
add_self_loops,
|
||||||
|
|||||||
@ -22,7 +22,7 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
from prepare_lang import (
|
from .prepare_lang import (
|
||||||
add_disambig_symbols,
|
add_disambig_symbols,
|
||||||
generate_id_map,
|
generate_id_map,
|
||||||
get_phones,
|
get_phones,
|
||||||
|
|||||||
@ -66,16 +66,16 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from aishell import AIShell
|
from .aishell import AIShell
|
||||||
from asr_datamodule import AsrDataModule
|
from .asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from .beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from .train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
|||||||
@ -48,7 +48,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from .train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
|||||||
@ -20,8 +20,8 @@ from typing import Optional
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from .encoder_interface import EncoderInterface
|
||||||
from scaling import ScaledLinear
|
from .scaling import ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
|
|
||||||
|
|||||||
@ -65,7 +65,7 @@ import k2
|
|||||||
import kaldifeat
|
import kaldifeat
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from .beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
@ -73,7 +73,7 @@ from beam_search import (
|
|||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from .train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
|
||||||
|
|||||||
@ -58,23 +58,22 @@ from shutil import copyfile
|
|||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import optim
|
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from aidatatang_200zh import AIDatatang200zh
|
from .aidatatang_200zh import AIDatatang200zh
|
||||||
from aishell import AIShell
|
from .aishell import AIShell
|
||||||
from asr_datamodule import AsrDataModule
|
from .asr_datamodule import AsrDataModule
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
from joiner import Joiner
|
from .joiner import Joiner
|
||||||
from lhotse import CutSet, load_manifest
|
from lhotse import CutSet, load_manifest
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
from optim import Eden, Eve
|
from .optim import Eden, Eve, LRScheduler
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
@ -94,7 +93,7 @@ from icefall.lexicon import Lexicon
|
|||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
LRSchedulerType = Union[
|
LRSchedulerType = Union[
|
||||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -24,8 +24,8 @@ from typing import Dict, List, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from model import TdnnLstm
|
from .model import TdnnLstm
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.decode import get_lattice, nbest_decoding, one_best_decoding
|
from icefall.decode import get_lattice, nbest_decoding, one_best_decoding
|
||||||
|
|||||||
@ -25,7 +25,7 @@ import k2
|
|||||||
import kaldifeat
|
import kaldifeat
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from model import TdnnLstm
|
from .model import TdnnLstm
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.decode import get_lattice, one_best_decoding
|
from icefall.decode import get_lattice, one_best_decoding
|
||||||
|
|||||||
@ -36,9 +36,9 @@ import torch.distributed as dist
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import TdnnLstm
|
from .model import TdnnLstm
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.optim.lr_scheduler import StepLR
|
from torch.optim.lr_scheduler import StepLR
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from typing import Dict, List, Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformer import Transformer
|
from .transformer import Transformer
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|||||||
@ -24,12 +24,12 @@ from typing import Dict, List, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from beam_search import beam_search, greedy_search
|
from .beam_search import beam_search, greedy_search
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
from joiner import Joiner
|
from .joiner import Joiner
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
|||||||
@ -49,10 +49,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
from joiner import Joiner
|
from .joiner import Joiner
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
|||||||
@ -17,7 +17,7 @@
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from .encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
|
|
||||||
|
|||||||
@ -51,11 +51,11 @@ import kaldifeat
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from .beam_search import beam_search, greedy_search
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
from joiner import Joiner
|
from .joiner import Joiner
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
|
|||||||
@ -23,7 +23,7 @@ To run this file, do:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
|
|
||||||
|
|
||||||
def test_decoder():
|
def test_decoder():
|
||||||
|
|||||||
@ -30,18 +30,18 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
from joiner import Joiner
|
from .joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from transformer import Noam
|
from .transformer import Noam
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
|||||||
@ -20,8 +20,8 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from .encoder_interface import EncoderInterface
|
||||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|||||||
@ -63,16 +63,16 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from aishell import AIShell
|
from .aishell import AIShell
|
||||||
from asr_datamodule import AsrDataModule
|
from .asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from .beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import get_params, get_transducer_model
|
from .train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
|||||||
@ -48,10 +48,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
from joiner import Joiner
|
from .joiner import Joiner
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from typing import Optional
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from .encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
|
|
||||||
|
|||||||
@ -65,7 +65,7 @@ import k2
|
|||||||
import kaldifeat
|
import kaldifeat
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from .beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
@ -73,7 +73,7 @@ from beam_search import (
|
|||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import get_params, get_transducer_model
|
from .train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
|
||||||
|
|||||||
@ -50,21 +50,21 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from aidatatang_200zh import AIDatatang200zh
|
from .aidatatang_200zh import AIDatatang200zh
|
||||||
from aishell import AIShell
|
from .aishell import AIShell
|
||||||
from asr_datamodule import AsrDataModule
|
from .asr_datamodule import AsrDataModule
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
from joiner import Joiner
|
from .joiner import Joiner
|
||||||
from lhotse import CutSet, load_manifest
|
from lhotse import CutSet, load_manifest
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from transformer import Noam
|
from .transformer import Noam
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
|||||||
@ -65,15 +65,15 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from beam_search import (
|
from .beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import get_params, get_transducer_model
|
from .train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
|||||||
@ -48,10 +48,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
from joiner import Joiner
|
from .joiner import Joiner
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
|||||||
@ -23,7 +23,7 @@ To run this file, do:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
|
|
||||||
|
|
||||||
def test_decoder():
|
def test_decoder():
|
||||||
|
|||||||
@ -46,18 +46,18 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from .asr_datamodule import AishellAsrDataModule
|
||||||
from conformer import Conformer
|
from .conformer import Conformer
|
||||||
from decoder import Decoder
|
from .decoder import Decoder
|
||||||
from joiner import Joiner
|
from .joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import Transducer
|
from .model import Transducer
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from transformer import Noam
|
from .transformer import Noam
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
|||||||
Reference in New Issue
Block a user