mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
use explicit relative imports
This commit is contained in:
parent
10846568b9
commit
5657fa44a4
@ -65,7 +65,7 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
@ -73,7 +73,7 @@ from beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import get_params, get_transducer_model
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
@ -63,7 +63,7 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
|
@ -27,8 +27,8 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import GigaSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import GigaSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
@ -34,8 +34,8 @@ from pathlib import Path
|
||||
import k2
|
||||
import numpy as np
|
||||
import torch
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse import CutSet
|
||||
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
|
||||
|
||||
|
@ -21,7 +21,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
|
@ -26,8 +26,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
|
@ -24,7 +24,7 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from conformer import Conformer
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
|
@ -27,7 +27,7 @@ import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from conformer import Conformer
|
||||
from .conformer import Conformer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.decode import (
|
||||
|
@ -18,7 +18,7 @@
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
|
||||
torch_ver = LooseVersion(torch.__version__)
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
|
||||
import torch
|
||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
|
||||
|
||||
def test_conv2d_subsampling():
|
||||
|
@ -18,7 +18,7 @@
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformer import (
|
||||
from .transformer import (
|
||||
Transformer,
|
||||
add_eos,
|
||||
add_sos,
|
||||
|
@ -38,15 +38,15 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
|
@ -19,8 +19,8 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||
|
@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn.init import xavier_normal_
|
||||
|
||||
from scaling import ScaledLinear
|
||||
from .scaling import ScaledLinear
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
|
||||
|
@ -22,7 +22,7 @@ import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
@ -30,9 +30,9 @@ from scaling import (
|
||||
ScaledLinear,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
from subsampling import Conv2dSubsampling
|
||||
from .subsampling import Conv2dSubsampling
|
||||
|
||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
|
@ -28,8 +28,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -47,7 +47,7 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from decode import get_params
|
||||
from .decode import get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -55,7 +55,7 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from conformer import Conformer
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from icefall.lexicon import Lexicon
|
||||
|
@ -17,7 +17,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
|
@ -54,16 +54,15 @@ from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from optim import Eden, Eve
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -90,7 +89,7 @@ from icefall.utils import (
|
||||
)
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
||||
|
||||
|
@ -21,12 +21,12 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from subsampling import Conv2dSubsampling
|
||||
from attention import MultiheadAttention
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
from .subsampling import Conv2dSubsampling
|
||||
from .attention import MultiheadAttention
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from scaling import (
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
|
@ -22,7 +22,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
|
@ -26,8 +26,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
|
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from subsampling import Conv2dSubsampling
|
||||
from subsampling import VggSubsampling
|
||||
from .subsampling import Conv2dSubsampling
|
||||
from .subsampling import VggSubsampling
|
||||
import torch
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from transformer import (
|
||||
from .transformer import (
|
||||
Transformer,
|
||||
encoder_padding_mask,
|
||||
generate_square_subsequent_mask,
|
||||
|
@ -28,13 +28,13 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.ali import (
|
||||
convert_alignments_to_tensor,
|
||||
|
@ -28,13 +28,13 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.ali import (
|
||||
convert_alignments_to_tensor,
|
||||
|
@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||
|
@ -80,15 +80,15 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -23,8 +23,8 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import (
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
|
@ -64,7 +64,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -20,7 +20,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from beam_search import Hypothesis, HypothesisList
|
||||
from .beam_search import Hypothesis, HypothesisList
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
|
@ -79,13 +79,13 @@ import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from .emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from stream import Stream
|
||||
from .stream import Stream
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -19,7 +19,7 @@
|
||||
|
||||
|
||||
import torch
|
||||
from emformer import ConvolutionModule, Emformer, stack_states, unstack_states
|
||||
from .emformer import ConvolutionModule, Emformer, stack_states, unstack_states
|
||||
|
||||
|
||||
def test_convolution_module_forward():
|
||||
|
@ -65,20 +65,19 @@ from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decoder import Decoder
|
||||
from emformer import Emformer
|
||||
from joiner import Joiner
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .decoder import Decoder
|
||||
from .emformer import Emformer
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -96,7 +95,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
||||
|
||||
|
@ -80,15 +80,15 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -23,8 +23,8 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import (
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
|
@ -64,7 +64,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -79,13 +79,13 @@ import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from .emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from stream import Stream
|
||||
from .stream import Stream
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -19,7 +19,7 @@
|
||||
|
||||
|
||||
import torch
|
||||
from emformer import ConvolutionModule, Emformer, stack_states, unstack_states
|
||||
from .emformer import ConvolutionModule, Emformer, stack_states, unstack_states
|
||||
|
||||
|
||||
def test_convolution_module_forward():
|
||||
|
@ -65,20 +65,19 @@ from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decoder import Decoder
|
||||
from emformer import Emformer
|
||||
from joiner import Joiner
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .decoder import Decoder
|
||||
from .emformer import Emformer
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -96,7 +95,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
||||
|
||||
|
@ -22,7 +22,7 @@ import os
|
||||
import tempfile
|
||||
|
||||
import k2
|
||||
from prepare_lang import (
|
||||
from .prepare_lang import (
|
||||
add_disambig_symbols,
|
||||
generate_id_map,
|
||||
get_phones,
|
||||
|
@ -68,15 +68,15 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -19,8 +19,8 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
|
||||
try:
|
||||
from torchaudio.models import Emformer as _Emformer
|
||||
|
@ -50,7 +50,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -18,7 +18,7 @@
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from .encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
@ -23,7 +23,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
from emformer import Emformer, stack_states, unstack_states
|
||||
from .emformer import Emformer, stack_states, unstack_states
|
||||
|
||||
|
||||
def test_emformer():
|
||||
|
@ -24,7 +24,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
|
@ -45,15 +45,15 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decoder import Decoder
|
||||
from emformer import Emformer
|
||||
from joiner import Joiner
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .decoder import Decoder
|
||||
from .emformer import Emformer
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from noam import Noam
|
||||
from .model import Transducer
|
||||
from .noam import Noam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
@ -21,7 +21,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from model import Transducer
|
||||
from .model import Transducer
|
||||
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.utils import get_texts
|
||||
|
@ -115,8 +115,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
@ -126,7 +126,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -19,7 +19,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from beam_search import Hypothesis, HypothesisList
|
||||
from .beam_search import Hypothesis, HypothesisList
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
|
@ -49,7 +49,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.utils import str2bool
|
||||
|
@ -18,7 +18,7 @@
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from .encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
@ -69,7 +69,7 @@ import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
@ -77,7 +77,7 @@ from beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
@ -20,8 +20,8 @@ from typing import List
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from decode_stream import DecodeStream
|
||||
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from .decode_stream import DecodeStream
|
||||
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.utils import get_texts
|
||||
|
@ -39,17 +39,17 @@ import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from streaming_beam_search import (
|
||||
from .streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -23,7 +23,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
from decoder import Decoder
|
||||
from .decoder import Decoder
|
||||
|
||||
|
||||
def test_decoder():
|
||||
|
@ -24,7 +24,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
|
@ -56,19 +56,19 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from .model import Transducer
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
|
@ -119,8 +119,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
@ -130,7 +130,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -49,7 +49,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -69,7 +69,7 @@ import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
@ -77,7 +77,7 @@ from beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
@ -20,8 +20,8 @@ from typing import List
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from decode_stream import DecodeStream
|
||||
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from .decode_stream import DecodeStream
|
||||
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.utils import get_texts
|
||||
|
@ -39,17 +39,17 @@ import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from streaming_beam_search import (
|
||||
from .streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -63,20 +63,19 @@ from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -91,7 +90,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
||||
|
||||
|
@ -66,8 +66,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
@ -75,9 +75,9 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from gigaspeech import GigaSpeech
|
||||
from .gigaspeech import GigaSpeech
|
||||
from gigaspeech_scoring import asr_text_post_processing
|
||||
from train import get_params, get_transducer_model
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -104,8 +104,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
@ -117,8 +117,8 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from librispeech import LibriSpeech
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .librispeech import LibriSpeech
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -50,7 +50,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -20,8 +20,8 @@ from typing import Optional
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
@ -69,7 +69,7 @@ import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
@ -77,7 +77,7 @@ from beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
@ -39,18 +39,18 @@ import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from librispeech import LibriSpeech
|
||||
from streaming_beam_search import (
|
||||
from .librispeech import LibriSpeech
|
||||
from .streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -24,7 +24,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
from scaling import ActivationBalancer, ScaledConv1d, ScaledConv2d
|
||||
from .scaling import ActivationBalancer, ScaledConv1d, ScaledConv2d
|
||||
|
||||
|
||||
def test_scaled_conv1d():
|
||||
|
@ -61,18 +61,18 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from gigaspeech import GigaSpeech
|
||||
from joiner import Joiner
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .gigaspeech import GigaSpeech
|
||||
from .joiner import Joiner
|
||||
from lhotse import CutSet, load_manifest
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from librispeech import LibriSpeech
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from .librispeech import LibriSpeech
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
@ -120,8 +120,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
@ -131,7 +131,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -50,7 +50,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -39,17 +39,17 @@ import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from streaming_beam_search import (
|
||||
from .streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -65,20 +65,19 @@ from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -96,7 +95,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
||||
|
||||
|
@ -105,8 +105,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
@ -116,7 +116,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -50,7 +50,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -69,7 +69,7 @@ import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
@ -77,7 +77,7 @@ from beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -39,17 +39,17 @@ import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from streaming_beam_search import (
|
||||
from .streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -23,7 +23,7 @@ To run this file, do:
|
||||
python ./pruned_transducer_stateless4/test_model.py
|
||||
"""
|
||||
|
||||
from train import get_params, get_transducer_model
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model_1():
|
||||
|
@ -53,20 +53,19 @@ from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -84,7 +83,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
||||
|
||||
|
@ -21,8 +21,8 @@ import warnings
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import (
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
|
@ -67,15 +67,15 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import get_params, get_transducer_model
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -49,7 +49,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
@ -21,9 +21,9 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from vq_utils import CodebookIndexExtractor
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from hubert_xlarge import HubertXlargeFineTuned
|
||||
from .vq_utils import CodebookIndexExtractor
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .hubert_xlarge import HubertXlargeFineTuned
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
|
||||
|
@ -24,8 +24,8 @@ from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from hubert_xlarge import HubertXlargeFineTuned
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .hubert_xlarge import HubertXlargeFineTuned
|
||||
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
|
@ -18,8 +18,8 @@
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
@ -24,7 +24,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
from train import get_params, get_transducer_model
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
|
@ -64,21 +64,20 @@ from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut, MonoCut
|
||||
from lhotse.dataset.collation import collate_custom_field
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -96,7 +95,7 @@ from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
||||
|
||||
|
@ -30,8 +30,8 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import quantization
|
||||
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from hubert_xlarge import HubertXlargeFineTuned
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .hubert_xlarge import HubertXlargeFineTuned
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
|
@ -22,7 +22,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42
|
||||
|
@ -26,8 +26,8 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
|
@ -28,14 +28,14 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
|
@ -20,8 +20,8 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||
|
@ -25,8 +25,8 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from model import TdnnLstm
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .model import TdnnLstm
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.decode import (
|
||||
|
@ -26,7 +26,7 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from model import TdnnLstm
|
||||
from .model import TdnnLstm
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.decode import (
|
||||
|
@ -37,10 +37,10 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import TdnnLstm
|
||||
from .model import TdnnLstm
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
@ -18,7 +18,7 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from model import Transducer
|
||||
from .model import Transducer
|
||||
|
||||
|
||||
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user