sort imports

This commit is contained in:
shaynemei 2022-08-02 17:45:23 -07:00
parent 5657fa44a4
commit ca19b98949
258 changed files with 1035 additions and 907 deletions

View File

@ -39,6 +39,7 @@ from typing import Dict, List
import k2
import torch
from .prepare_lang import (
Lexicon,
add_disambig_symbols,

View File

@ -22,6 +22,7 @@ import os
import tempfile
import k2
from .prepare_lang import (
add_disambig_symbols,
generate_id_map,

View File

@ -59,6 +59,16 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
from .asr_datamodule import Aidatatang_200zhAsrDataModule
from .beam_search import (
beam_search,
@ -69,19 +79,6 @@ from .beam_search import (
)
from .train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -46,12 +46,13 @@ import logging
from pathlib import Path
import torch
from .train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -63,6 +63,10 @@ import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
@ -70,11 +74,8 @@ from .beam_search import (
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from .train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -56,15 +56,9 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import Aidatatang_200zhAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -80,6 +74,13 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import Aidatatang_200zhAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -22,6 +22,7 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from .transformer import Supervisions, Transformer, encoder_padding_mask

View File

@ -26,8 +26,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -48,6 +46,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -24,12 +24,13 @@ import logging
from pathlib import Path
import torch
from .conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -26,7 +26,6 @@ import k2
import kaldifeat
import torch
import torchaudio
from .conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
@ -36,6 +35,8 @@ from icefall.decode import (
)
from icefall.utils import AttributeDict, get_texts
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -16,10 +16,10 @@
# limitations under the License.
from .subsampling import Conv2dSubsampling
from .subsampling import VggSubsampling
import torch
from .subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():
N = 3

View File

@ -18,6 +18,7 @@
import torch
from torch.nn.utils.rnn import pad_sequence
from .transformer import (
Transformer,
add_eos,

View File

@ -26,13 +26,10 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from .transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -48,6 +45,10 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,9 +20,10 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from .label_smoothing import LabelSmoothingLoss
from .subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -22,6 +22,7 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from .transformer import Supervisions, Transformer, encoder_padding_mask

View File

@ -26,8 +26,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import (
@ -49,6 +47,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -28,13 +28,10 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from .transformer import Noam
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@ -51,6 +48,10 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,9 +20,10 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from .label_smoothing import LabelSmoothingLoss
from .subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -39,6 +39,7 @@ from typing import Dict, List
import k2
import torch
from .prepare_lang import (
Lexicon,
add_disambig_symbols,

View File

@ -22,6 +22,7 @@ import os
import tempfile
import k2
from .prepare_lang import (
add_disambig_symbols,
generate_id_map,

View File

@ -66,16 +66,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -92,6 +82,17 @@ from icefall.utils import (
write_error_stats,
)
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -48,7 +48,6 @@ import logging
from pathlib import Path
import torch
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -59,6 +58,8 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,11 +20,12 @@ from typing import Optional
import k2
import torch
import torch.nn as nn
from .encoder_interface import EncoderInterface
from .scaling import ScaledLinear
from icefall.utils import add_sos
from .encoder_interface import EncoderInterface
from .scaling import ScaledLinear
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf

View File

@ -65,6 +65,10 @@ import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
@ -72,11 +76,8 @@ from .beam_search import (
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -61,19 +61,10 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .aidatatang_200zh import AIDatatang200zh
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -92,6 +83,15 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .aidatatang_200zh import AIDatatang200zh
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -24,8 +24,6 @@ from typing import Dict, List, Tuple
import k2
import torch
import torch.nn as nn
from .asr_datamodule import AishellAsrDataModule
from .model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice, nbest_decoding, one_best_decoding
@ -39,6 +37,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .model import TdnnLstm
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -25,12 +25,13 @@ import k2
import kaldifeat
import torch
import torchaudio
from .model import TdnnLstm
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
from .model import TdnnLstm
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -36,9 +36,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from .asr_datamodule import AishellAsrDataModule
from lhotse.utils import fix_random_seed
from .model import TdnnLstm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
@ -49,12 +47,10 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
from .asr_datamodule import AishellAsrDataModule
from .model import TdnnLstm
def get_parser():

View File

@ -19,6 +19,7 @@ from typing import Dict, List, Optional
import numpy as np
import torch
from .model import Transducer

View File

@ -22,10 +22,11 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from .transformer import Transformer
from icefall.utils import make_pad_mask
from .transformer import Transformer
class Conformer(Transformer):
"""

View File

@ -24,12 +24,6 @@ from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from .asr_datamodule import AishellAsrDataModule
from .beam_search import beam_search, greedy_search
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
@ -38,10 +32,17 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
str2bool,
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .beam_search import beam_search, greedy_search
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -49,16 +49,17 @@ from pathlib import Path
import torch
import torch.nn as nn
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -17,10 +17,11 @@
import k2
import torch
import torch.nn as nn
from .encoder_interface import EncoderInterface
from icefall.utils import add_sos
from .encoder_interface import EncoderInterface
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf

View File

@ -51,11 +51,6 @@ import kaldifeat
import torch
import torch.nn as nn
import torchaudio
from .beam_search import beam_search, greedy_search
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
@ -63,6 +58,12 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict
from .beam_search import beam_search, greedy_search
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -23,6 +23,7 @@ To run this file, do:
"""
import torch
from .decoder import Decoder

View File

@ -30,18 +30,12 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from .model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from .transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -51,6 +45,13 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,11 +20,12 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from .encoder_interface import EncoderInterface
from .subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import make_pad_mask
from .encoder_interface import EncoderInterface
from .subsampling import Conv2dSubsampling, VggSubsampling
class Transformer(EncoderInterface):
def __init__(

View File

@ -29,10 +29,7 @@ from lhotse.dataset import (
K2SpeechRecognitionDataset,
SpecAugment,
)
from lhotse.dataset.input_strategies import (
OnTheFlyFeatures,
PrecomputedFeatures,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
from torch.utils.data import DataLoader
from icefall.utils import str2bool

View File

@ -63,6 +63,16 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .beam_search import (
@ -74,15 +84,6 @@ from .beam_search import (
)
from .train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -48,16 +48,17 @@ from pathlib import Path
import torch
import torch.nn as nn
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,10 +20,11 @@ from typing import Optional
import k2
import torch
import torch.nn as nn
from .encoder_interface import EncoderInterface
from icefall.utils import add_sos
from .encoder_interface import EncoderInterface
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf

View File

@ -65,6 +65,10 @@ import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
@ -72,11 +76,8 @@ from .beam_search import (
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from .train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -50,21 +50,13 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .aidatatang_200zh import AIDatatang200zh
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from .model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from .transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -74,6 +66,15 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .aidatatang_200zh import AIDatatang200zh
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -65,15 +65,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from .asr_datamodule import AishellAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
@ -84,6 +75,16 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -48,16 +48,17 @@ from pathlib import Path
import torch
import torch.nn as nn
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -65,6 +65,10 @@ import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
@ -72,11 +76,8 @@ from .beam_search import (
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from .train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -23,6 +23,7 @@ To run this file, do:
"""
import torch
from .decoder import Decoder

View File

@ -46,18 +46,12 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from .model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from .transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -67,6 +61,13 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -110,18 +110,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from .asr_datamodule import AiShell2AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
@ -139,6 +127,19 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AiShell2AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -49,7 +49,6 @@ import logging
from pathlib import Path
import torch
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -60,6 +59,8 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -59,6 +59,10 @@ import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
@ -66,11 +70,8 @@ from .beam_search import (
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -64,15 +64,9 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import AiShell2AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -91,6 +85,13 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import AiShell2AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -39,6 +39,7 @@ from typing import Dict, List
import k2
import torch
from .prepare_lang import (
Lexicon,
add_disambig_symbols,

View File

@ -22,6 +22,7 @@ import os
import tempfile
import k2
from .prepare_lang import (
add_disambig_symbols,
generate_id_map,

View File

@ -61,17 +61,7 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from .asr_datamodule import Aishell4AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from ..local.text_normalize import text_normalize
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -88,6 +78,17 @@ from icefall.utils import (
write_error_stats,
)
from ..local.text_normalize import text_normalize
from .asr_datamodule import Aishell4AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -49,7 +49,6 @@ import logging
from pathlib import Path
import torch
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -60,6 +59,8 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -72,6 +72,10 @@ import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
@ -79,11 +83,8 @@ from .beam_search import (
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -56,16 +56,9 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import Aishell4AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from ..local.text_normalize import text_normalize
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -84,6 +77,14 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from ..local.text_normalize import text_normalize
from .asr_datamodule import Aishell4AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -39,6 +39,7 @@ from typing import Dict, List
import k2
import torch
from .prepare_lang import (
Lexicon,
add_disambig_symbols,

View File

@ -22,6 +22,7 @@ import os
import tempfile
import k2
from .prepare_lang import (
add_disambig_symbols,
generate_id_map,

View File

@ -59,6 +59,17 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from lhotse.cut import Cut
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
from .asr_datamodule import AlimeetingAsrDataModule
from .beam_search import (
beam_search,
@ -67,22 +78,8 @@ from .beam_search import (
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from .train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -46,12 +46,13 @@ import logging
from pathlib import Path
import torch
from .train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -63,6 +63,10 @@ import k2
import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
@ -70,11 +74,8 @@ from .beam_search import (
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from .train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -56,15 +56,9 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import AlimeetingAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -80,6 +74,13 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import AlimeetingAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -27,8 +27,6 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import GigaSpeechAsrDataModule
from .conformer import Conformer
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
@ -49,6 +47,8 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import GigaSpeechAsrDataModule
from .conformer import Conformer
from .transformer import Noam

View File

@ -34,8 +34,6 @@ from pathlib import Path
import k2
import numpy as np
import torch
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from lhotse import CutSet
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
@ -51,6 +49,9 @@ from icefall.utils import (
setup_logger,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -21,6 +21,7 @@ from typing import Optional, Tuple, Union
import torch
from torch import Tensor, nn
from .transformer import Supervisions, Transformer, encoder_padding_mask

View File

@ -26,8 +26,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -54,6 +52,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -24,12 +24,13 @@ import logging
from pathlib import Path
import torch
from .conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -27,7 +27,6 @@ import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from .conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
@ -38,6 +37,8 @@ from icefall.decode import (
)
from icefall.utils import AttributeDict, get_texts
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -18,6 +18,7 @@
from distutils.version import LooseVersion
import torch
from .label_smoothing import LabelSmoothingLoss
torch_ver = LooseVersion(torch.__version__)

View File

@ -17,6 +17,7 @@
import torch
from .subsampling import Conv2dSubsampling, VggSubsampling

View File

@ -18,6 +18,7 @@
import torch
from torch.nn.utils.rnn import pad_sequence
from .transformer import (
Transformer,
add_eos,

View File

@ -38,15 +38,12 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from .transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -63,6 +60,10 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -19,9 +19,10 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from .label_smoothing import LabelSmoothingLoss
from .subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -22,6 +22,8 @@ import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from .scaling import (
ActivationBalancer,
BasicNorm,
@ -29,9 +31,7 @@ from .scaling import (
ScaledConv1d,
ScaledLinear,
)
from torch import Tensor, nn
from .subsampling import Conv2dSubsampling
from .transformer import Supervisions, Transformer, encoder_padding_mask

View File

@ -28,17 +28,14 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.decode import (
get_lattice,
nbest_decoding,
@ -62,6 +59,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -47,7 +47,6 @@ import logging
from pathlib import Path
import torch
from .decode import get_params
from icefall.checkpoint import (
average_checkpoints,
@ -55,10 +54,11 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from .conformer import Conformer
from icefall.utils import str2bool
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .conformer import Conformer
from .decode import get_params
def get_parser():

View File

@ -17,6 +17,8 @@
# limitations under the License.
import torch
from torch import nn
from .scaling import (
ActivationBalancer,
BasicNorm,
@ -24,7 +26,6 @@ from .scaling import (
ScaledConv2d,
ScaledLinear,
)
from torch import nn
class Conv2dSubsampling(nn.Module):

View File

@ -57,19 +57,16 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from .optim import Eden, Eve, LRScheduler
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall import diagnostics
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -88,6 +85,10 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -21,19 +21,18 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from .label_smoothing import LabelSmoothingLoss
from .subsampling import Conv2dSubsampling
from .attention import MultiheadAttention
from torch.nn.utils.rnn import pad_sequence
from .attention import MultiheadAttention
from .label_smoothing import LabelSmoothingLoss
from .scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledLinear,
ScaledEmbedding,
ScaledLinear,
)
from .subsampling import Conv2dSubsampling
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -22,6 +22,7 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from .transformer import Supervisions, Transformer, encoder_padding_mask

View File

@ -26,8 +26,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -50,6 +48,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -1,9 +1,9 @@
#!/usr/bin/env python3
from .subsampling import Conv2dSubsampling
from .subsampling import VggSubsampling
import torch
from .subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():
N = 3

View File

@ -1,17 +1,17 @@
#!/usr/bin/env python3
import torch
from torch.nn.utils.rnn import pad_sequence
from .transformer import (
Transformer,
add_eos,
add_sos,
decoder_padding_mask,
encoder_padding_mask,
generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
)
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask():
supervisions = {

View File

@ -28,31 +28,23 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from .transformer import Noam
from icefall.ali import (
convert_alignments_to_tensor,
load_alignments,
lookup_alignments,
)
from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.mmi import LFMMILoss
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():

View File

@ -28,31 +28,23 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from .transformer import Noam
from icefall.ali import (
convert_alignments_to_tensor,
load_alignments,
lookup_alignments,
)
from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.mmi import LFMMILoss
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():

View File

@ -20,9 +20,10 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from .subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
from .subsampling import Conv2dSubsampling, VggSubsampling
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -80,15 +80,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -104,6 +95,16 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
LOG_EPS = math.log(1e-10)

View File

@ -23,6 +23,9 @@ from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from icefall.utils import make_pad_mask
from .encoder_interface import EncoderInterface
from .scaling import (
ActivationBalancer,
@ -33,9 +36,6 @@ from .scaling import (
ScaledLinear,
)
from icefall.utils import make_pad_mask
LOG_EPSILON = math.log(1e-10)

View File

@ -64,7 +64,6 @@ from pathlib import Path
import sentencepiece as spm
import torch
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -74,6 +73,8 @@ from icefall.checkpoint import (
)
from icefall.utils import str2bool
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,10 +20,11 @@ from typing import List, Optional, Tuple
import k2
import torch
from .beam_search import Hypothesis, HypothesisList
from icefall.utils import AttributeDict
from .beam_search import Hypothesis, HypothesisList
class Stream(object):
def __init__(

View File

@ -74,18 +74,13 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
from lhotse import CutSet
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
from .emformer import LOG_EPSILON, stack_states, unstack_states
from kaldifeat import Fbank, FbankOptions
from .stream import Stream
from lhotse import CutSet
from torch.nn.utils.rnn import pad_sequence
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -103,6 +98,12 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
from .emformer import LOG_EPSILON, stack_states, unstack_states
from .stream import Stream
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -19,6 +19,7 @@
import torch
from .emformer import ConvolutionModule, Emformer, stack_states, unstack_states

View File

@ -69,15 +69,9 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .decoder import Decoder
from .emformer import Emformer
from .joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -94,6 +88,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import LibriSpeechAsrDataModule
from .decoder import Decoder
from .emformer import Emformer
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -80,15 +80,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -104,6 +95,16 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
LOG_EPS = math.log(1e-10)

View File

@ -23,6 +23,9 @@ from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from icefall.utils import make_pad_mask
from .encoder_interface import EncoderInterface
from .scaling import (
ActivationBalancer,
@ -33,9 +36,6 @@ from .scaling import (
ScaledLinear,
)
from icefall.utils import make_pad_mask
LOG_EPSILON = math.log(1e-10)

View File

@ -64,7 +64,6 @@ from pathlib import Path
import sentencepiece as spm
import torch
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -74,6 +73,8 @@ from icefall.checkpoint import (
)
from icefall.utils import str2bool
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -74,18 +74,13 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
from lhotse import CutSet
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from .asr_datamodule import LibriSpeechAsrDataModule
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
from .emformer import LOG_EPSILON, stack_states, unstack_states
from kaldifeat import Fbank, FbankOptions
from .stream import Stream
from lhotse import CutSet
from torch.nn.utils.rnn import pad_sequence
from .train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -103,6 +98,12 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .beam_search import Hypothesis, HypothesisList, get_hyps_shape
from .emformer import LOG_EPSILON, stack_states, unstack_states
from .stream import Stream
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -19,6 +19,7 @@
import torch
from .emformer import ConvolutionModule, Emformer, stack_states, unstack_states

Some files were not shown because too many files have changed in this diff Show More