mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
sort imports
This commit is contained in:
parent
5657fa44a4
commit
ca19b98949
@ -39,6 +39,7 @@ from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from .prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
|
@ -22,6 +22,7 @@ import os
|
||||
import tempfile
|
||||
|
||||
import k2
|
||||
|
||||
from .prepare_lang import (
|
||||
add_disambig_symbols,
|
||||
generate_id_map,
|
||||
|
@ -59,6 +59,16 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import Aidatatang_200zhAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
@ -69,19 +79,6 @@ from .beam_search import (
|
||||
)
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -46,12 +46,13 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -63,6 +63,10 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
@ -70,11 +74,8 @@ from .beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -56,15 +56,9 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import Aidatatang_200zhAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -80,6 +74,13 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
from .asr_datamodule import Aidatatang_200zhAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
@ -22,6 +22,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
|
@ -26,8 +26,6 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
@ -48,6 +46,9 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -24,12 +24,13 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -26,7 +26,6 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from .conformer import Conformer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.decode import (
|
||||
@ -36,6 +35,8 @@ from icefall.decode import (
|
||||
)
|
||||
from icefall.utils import AttributeDict, get_texts
|
||||
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -16,10 +16,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from .subsampling import Conv2dSubsampling
|
||||
from .subsampling import VggSubsampling
|
||||
import torch
|
||||
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
|
||||
|
||||
def test_conv2d_subsampling():
|
||||
N = 3
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .transformer import (
|
||||
Transformer,
|
||||
add_eos,
|
||||
|
@ -26,13 +26,10 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
@ -48,6 +45,10 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
)
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .transformer import Noam
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -20,9 +20,10 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||
Supervisions = Dict[str, torch.Tensor]
|
||||
|
@ -22,6 +22,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
|
@ -26,8 +26,6 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.decode import (
|
||||
@ -49,6 +47,9 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -28,13 +28,10 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
@ -51,6 +48,10 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
)
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .transformer import Noam
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -20,9 +20,10 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||
Supervisions = Dict[str, torch.Tensor]
|
||||
|
@ -39,6 +39,7 @@ from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from .prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
|
@ -22,6 +22,7 @@ import os
|
||||
import tempfile
|
||||
|
||||
import k2
|
||||
|
||||
from .prepare_lang import (
|
||||
add_disambig_symbols,
|
||||
generate_id_map,
|
||||
|
@ -66,16 +66,6 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .aishell import AIShell
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -92,6 +82,17 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .aishell import AIShell
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -48,7 +48,6 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -59,6 +58,8 @@ from icefall.checkpoint import (
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -20,11 +20,12 @@ from typing import Optional
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .scaling import ScaledLinear
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
|
@ -65,6 +65,10 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
@ -72,11 +76,8 @@ from .beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -61,19 +61,10 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from .aidatatang_200zh import AIDatatang200zh
|
||||
from .aishell import AIShell
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse import CutSet, load_manifest
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -92,6 +83,15 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
from .aidatatang_200zh import AIDatatang200zh
|
||||
from .aishell import AIShell
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
@ -24,8 +24,6 @@ from typing import Dict, List, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .model import TdnnLstm
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.decode import get_lattice, nbest_decoding, one_best_decoding
|
||||
@ -39,6 +37,9 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .model import TdnnLstm
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -25,12 +25,13 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from .model import TdnnLstm
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.decode import get_lattice, one_best_decoding
|
||||
from icefall.utils import AttributeDict, get_texts
|
||||
|
||||
from .model import TdnnLstm
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -36,9 +36,7 @@ import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .model import TdnnLstm
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.optim.lr_scheduler import StepLR
|
||||
@ -49,12 +47,10 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .model import TdnnLstm
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -19,6 +19,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .model import Transducer
|
||||
|
||||
|
||||
|
@ -22,10 +22,11 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from .transformer import Transformer
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
from .transformer import Transformer
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
"""
|
||||
|
@ -24,12 +24,6 @@ from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .beam_search import beam_search, greedy_search
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
@ -38,10 +32,17 @@ from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .beam_search import beam_search, greedy_search
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -49,16 +49,17 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -17,10 +17,11 @@
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
from .encoder_interface import EncoderInterface
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
|
@ -51,11 +51,6 @@ import kaldifeat
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from .beam_search import beam_search, greedy_search
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
@ -63,6 +58,12 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
from .beam_search import beam_search, greedy_search
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -23,6 +23,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .decoder import Decoder
|
||||
|
||||
|
||||
|
@ -30,18 +30,12 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .model import Transducer
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
@ -51,6 +45,13 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from .transformer import Noam
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -20,11 +20,12 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
|
||||
|
||||
class Transformer(EncoderInterface):
|
||||
def __init__(
|
||||
|
@ -29,10 +29,7 @@ from lhotse.dataset import (
|
||||
K2SpeechRecognitionDataset,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import (
|
||||
OnTheFlyFeatures,
|
||||
PrecomputedFeatures,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
@ -63,6 +63,16 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .aishell import AIShell
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .beam_search import (
|
||||
@ -74,15 +84,6 @@ from .beam_search import (
|
||||
)
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -48,16 +48,17 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -20,10 +20,11 @@ from typing import Optional
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
from .encoder_interface import EncoderInterface
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
|
@ -65,6 +65,10 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
@ -72,11 +76,8 @@ from .beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -50,21 +50,13 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .aidatatang_200zh import AIDatatang200zh
|
||||
from .aishell import AIShell
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse import CutSet, load_manifest
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .model import Transducer
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
@ -74,6 +66,15 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
from .aidatatang_200zh import AIDatatang200zh
|
||||
from .aishell import AIShell
|
||||
from .asr_datamodule import AsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from .transformer import Noam
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -65,15 +65,6 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
@ -84,6 +75,16 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -48,16 +48,17 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -65,6 +65,10 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
@ -72,11 +76,8 @@ from .beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -23,6 +23,7 @@ To run this file, do:
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .decoder import Decoder
|
||||
|
||||
|
||||
|
@ -46,18 +46,12 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .model import Transducer
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
@ -67,6 +61,13 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
from .asr_datamodule import AishellAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from .transformer import Noam
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -110,18 +110,6 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AiShell2AsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
@ -139,6 +127,19 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import AiShell2AsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -49,7 +49,6 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -60,6 +59,8 @@ from icefall.checkpoint import (
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -59,6 +59,10 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
@ -66,11 +70,8 @@ from .beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -64,15 +64,9 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AiShell2AsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -91,6 +85,13 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
from .asr_datamodule import AiShell2AsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
@ -39,6 +39,7 @@ from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from .prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
|
@ -22,6 +22,7 @@ import os
|
||||
import tempfile
|
||||
|
||||
import k2
|
||||
|
||||
from .prepare_lang import (
|
||||
add_disambig_symbols,
|
||||
generate_id_map,
|
||||
|
@ -61,17 +61,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import Aishell4AsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.cut import Cut
|
||||
from ..local.text_normalize import text_normalize
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -88,6 +78,17 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from ..local.text_normalize import text_normalize
|
||||
from .asr_datamodule import Aishell4AsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -49,7 +49,6 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -60,6 +59,8 @@ from icefall.checkpoint import (
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -72,6 +72,10 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
@ -79,11 +83,8 @@ from .beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -56,16 +56,9 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import Aishell4AsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from ..local.text_normalize import text_normalize
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -84,6 +77,14 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
from ..local.text_normalize import text_normalize
|
||||
from .asr_datamodule import Aishell4AsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
@ -39,6 +39,7 @@ from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from .prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
|
@ -22,6 +22,7 @@ import os
|
||||
import tempfile
|
||||
|
||||
import k2
|
||||
|
||||
from .prepare_lang import (
|
||||
add_disambig_symbols,
|
||||
generate_id_map,
|
||||
|
@ -59,6 +59,17 @@ from typing import Dict, List, Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lhotse.cut import Cut
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import AlimeetingAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
@ -67,22 +78,8 @@ from .beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.cut import Cut
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -46,12 +46,13 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -63,6 +63,10 @@ import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
@ -70,11 +74,8 @@ from .beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from .train import get_params, get_transducer_model
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -56,15 +56,9 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import AlimeetingAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -80,6 +74,13 @@ from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
from .asr_datamodule import AlimeetingAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .decoder import Decoder
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
@ -27,8 +27,6 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import GigaSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -49,6 +47,8 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
)
|
||||
|
||||
from .asr_datamodule import GigaSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .transformer import Noam
|
||||
|
||||
|
||||
|
@ -34,8 +34,6 @@ from pathlib import Path
|
||||
import k2
|
||||
import numpy as np
|
||||
import torch
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse import CutSet
|
||||
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
|
||||
|
||||
@ -51,6 +49,9 @@ from icefall.utils import (
|
||||
setup_logger,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -21,6 +21,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
|
@ -26,8 +26,6 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
@ -54,6 +52,9 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -24,12 +24,13 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -27,7 +27,6 @@ import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from .conformer import Conformer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.decode import (
|
||||
@ -38,6 +37,8 @@ from icefall.decode import (
|
||||
)
|
||||
from icefall.utils import AttributeDict, get_texts
|
||||
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -18,6 +18,7 @@
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
|
||||
torch_ver = LooseVersion(torch.__version__)
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
|
||||
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .transformer import (
|
||||
Transformer,
|
||||
add_eos,
|
||||
|
@ -38,15 +38,12 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
@ -63,6 +60,10 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .transformer import Noam
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -19,9 +19,10 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||
Supervisions = Dict[str, torch.Tensor]
|
||||
|
@ -22,6 +22,8 @@ import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
@ -29,9 +31,7 @@ from .scaling import (
|
||||
ScaledConv1d,
|
||||
ScaledLinear,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
from .subsampling import Conv2dSubsampling
|
||||
|
||||
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
|
@ -28,17 +28,14 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
@ -62,6 +59,9 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -47,7 +47,6 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from .decode import get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -55,10 +54,11 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
from .conformer import Conformer
|
||||
from .decode import get_params
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -17,6 +17,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
@ -24,7 +26,6 @@ from .scaling import (
|
||||
ScaledConv2d,
|
||||
ScaledLinear,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
|
@ -57,19 +57,16 @@ import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall import diagnostics
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import (
|
||||
@ -88,6 +85,10 @@ from icefall.utils import (
|
||||
str2bool,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
@ -21,19 +21,18 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
from .subsampling import Conv2dSubsampling
|
||||
from .attention import MultiheadAttention
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .attention import MultiheadAttention
|
||||
from .label_smoothing import LabelSmoothingLoss
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
ScaledLinear,
|
||||
ScaledEmbedding,
|
||||
ScaledLinear,
|
||||
)
|
||||
|
||||
from .subsampling import Conv2dSubsampling
|
||||
|
||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||
Supervisions = Dict[str, torch.Tensor]
|
||||
|
@ -22,6 +22,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
|
@ -26,8 +26,6 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
@ -50,6 +48,9 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -1,9 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from .subsampling import Conv2dSubsampling
|
||||
from .subsampling import VggSubsampling
|
||||
import torch
|
||||
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
|
||||
|
||||
def test_conv2d_subsampling():
|
||||
N = 3
|
||||
|
@ -1,17 +1,17 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .transformer import (
|
||||
Transformer,
|
||||
add_eos,
|
||||
add_sos,
|
||||
decoder_padding_mask,
|
||||
encoder_padding_mask,
|
||||
generate_square_subsequent_mask,
|
||||
decoder_padding_mask,
|
||||
add_sos,
|
||||
add_eos,
|
||||
)
|
||||
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def test_encoder_padding_mask():
|
||||
supervisions = {
|
||||
|
@ -28,31 +28,23 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.ali import (
|
||||
convert_alignments_to_tensor,
|
||||
load_alignments,
|
||||
lookup_alignments,
|
||||
)
|
||||
from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.mmi import LFMMILoss
|
||||
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .transformer import Noam
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -28,31 +28,23 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from .transformer import Noam
|
||||
|
||||
from icefall.ali import (
|
||||
convert_alignments_to_tensor,
|
||||
load_alignments,
|
||||
lookup_alignments,
|
||||
)
|
||||
from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.mmi import LFMMILoss
|
||||
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .conformer import Conformer
|
||||
from .transformer import Noam
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -20,9 +20,10 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .subsampling import Conv2dSubsampling, VggSubsampling
|
||||
|
||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||
Supervisions = Dict[str, torch.Tensor]
|
||||
|
||||
|
@ -80,15 +80,6 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -104,6 +95,16 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
|
@ -23,6 +23,9 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
@ -33,9 +36,6 @@ from .scaling import (
|
||||
ScaledLinear,
|
||||
)
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
LOG_EPSILON = math.log(1e-10)
|
||||
|
||||
|
||||
|
@ -64,7 +64,6 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -74,6 +73,8 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -20,10 +20,11 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from .beam_search import Hypothesis, HypothesisList
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
from .beam_search import Hypothesis, HypothesisList
|
||||
|
||||
|
||||
class Stream(object):
|
||||
def __init__(
|
||||
|
@ -74,18 +74,13 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
from lhotse import CutSet
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from .emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from .stream import Stream
|
||||
from lhotse import CutSet
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -103,6 +98,12 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from .emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from .stream import Stream
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -19,6 +19,7 @@
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from .emformer import ConvolutionModule, Emformer, stack_states, unstack_states
|
||||
|
||||
|
||||
|
@ -69,15 +69,9 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .decoder import Decoder
|
||||
from .emformer import Emformer
|
||||
from .joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -94,6 +88,13 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .decoder import Decoder
|
||||
from .emformer import Emformer
|
||||
from .joiner import Joiner
|
||||
from .model import Transducer
|
||||
from .optim import Eden, Eve, LRScheduler
|
||||
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, LRScheduler
|
||||
]
|
||||
|
@ -80,15 +80,6 @@ import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -104,6 +95,16 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
|
@ -23,6 +23,9 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
from .encoder_interface import EncoderInterface
|
||||
from .scaling import (
|
||||
ActivationBalancer,
|
||||
@ -33,9 +36,6 @@ from .scaling import (
|
||||
ScaledLinear,
|
||||
)
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
LOG_EPSILON = math.log(1e-10)
|
||||
|
||||
|
||||
|
@ -64,7 +64,6 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -74,6 +73,8 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -74,18 +74,13 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
from lhotse import CutSet
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from .emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from .stream import Stream
|
||||
from lhotse import CutSet
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -103,6 +98,12 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from .asr_datamodule import LibriSpeechAsrDataModule
|
||||
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||
from .emformer import LOG_EPSILON, stack_states, unstack_states
|
||||
from .stream import Stream
|
||||
from .train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -19,6 +19,7 @@
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from .emformer import ConvolutionModule, Emformer, stack_states, unstack_states
|
||||
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user