Merge 561c428e508da15074afb4b3496f5f6aafbe65d5 into f24b76e64bb59e157e3904d0330a132a308f18c0

This commit is contained in:
Shayne Mei 2022-08-07 12:00:32 +08:00 committed by GitHub
commit f7c3d1ee28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
280 changed files with 1265 additions and 1145 deletions

View File

@ -39,7 +39,8 @@ from typing import Dict, List
import k2
import torch
from prepare_lang import (
from .prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,

View File

@ -22,7 +22,8 @@ import os
import tempfile
import k2
from prepare_lang import (
from .prepare_lang import (
add_disambig_symbols,
generate_id_map,
get_phones,

View File

@ -59,21 +59,8 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import Aidatatang_200zhAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -82,6 +69,16 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import Aidatatang_200zhAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -46,12 +46,13 @@ import logging
from pathlib import Path
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -63,17 +63,18 @@ import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
from .train import get_params, get_transducer_model
def get_parser():

View File

@ -53,19 +53,12 @@ from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import Aidatatang_200zhAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -81,8 +74,15 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import Aidatatang_200zhAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

View File

@ -22,7 +22,8 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
from .transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):

View File

@ -26,8 +26,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -48,6 +46,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -24,12 +24,13 @@ import logging
from pathlib import Path
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -26,7 +26,6 @@ import k2
import kaldifeat
import torch
import torchaudio
from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
@ -36,6 +35,8 @@ from icefall.decode import (
)
from icefall.utils import AttributeDict, get_texts
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -16,10 +16,10 @@
# limitations under the License.
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch
from .subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():
N = 3

View File

@ -18,7 +18,8 @@
import torch
from torch.nn.utils.rnn import pad_sequence
from transformer import (
from .transformer import (
Transformer,
add_eos,
add_sos,

View File

@ -26,13 +26,10 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -48,6 +45,10 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,10 +20,11 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
from .label_smoothing import LabelSmoothingLoss
from .subsampling import Conv2dSubsampling, VggSubsampling
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -22,7 +22,8 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
from .transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):

View File

@ -26,8 +26,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import (
@ -49,6 +47,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -28,13 +28,10 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@ -51,6 +48,10 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,10 +20,11 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
from .label_smoothing import LabelSmoothingLoss
from .subsampling import Conv2dSubsampling, VggSubsampling
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -39,7 +39,8 @@ from typing import Dict, List
import k2
import torch
from prepare_lang import (
from .prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,

View File

@ -22,7 +22,8 @@ import os
import tempfile
import k2
from prepare_lang import (
from .prepare_lang import (
add_disambig_symbols,
generate_id_map,
get_phones,

View File

@ -66,16 +66,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from aishell import AIShell
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -92,6 +82,17 @@ from icefall.utils import (
write_error_stats,
)
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -48,7 +48,6 @@ import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -59,6 +58,8 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,11 +20,12 @@ from typing import Optional
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from .encoder_interface import EncoderInterface
from .scaling import ScaledLinear
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf

View File

@ -65,17 +65,18 @@ import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():

View File

@ -58,23 +58,13 @@ from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from aidatatang_200zh import AIDatatang200zh
from aishell import AIShell
from asr_datamodule import AsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -93,8 +83,17 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .aidatatang_200zh import AIDatatang200zh
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -24,8 +24,6 @@ from typing import Dict, List, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice, nbest_decoding, one_best_decoding
@ -39,6 +37,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .model import TdnnLstm
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -25,12 +25,13 @@ import k2
import kaldifeat
import torch
import torchaudio
from model import TdnnLstm
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
from .model import TdnnLstm
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -36,9 +36,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import AishellAsrDataModule
from lhotse.utils import fix_random_seed
from model import TdnnLstm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
@ -49,12 +47,10 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
from .asr_datamodule import AishellAsrDataModule
from .model import TdnnLstm
def get_parser():

View File

@ -19,7 +19,8 @@ from typing import Dict, List, Optional
import numpy as np
import torch
from model import Transducer
from .model import Transducer
def greedy_search(

View File

@ -22,10 +22,11 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from transformer import Transformer
from icefall.utils import make_pad_mask
from .transformer import Transformer
class Conformer(Transformer):
"""

View File

@ -24,12 +24,6 @@ from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
@ -38,10 +32,17 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
str2bool,
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .beam_search import beam_search, greedy_search
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -49,16 +49,17 @@ from pathlib import Path
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -17,10 +17,11 @@
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
from .encoder_interface import EncoderInterface
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf

View File

@ -51,11 +51,6 @@ import kaldifeat
import torch
import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
@ -63,6 +58,12 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict
from .beam_search import beam_search, greedy_search
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -23,7 +23,8 @@ To run this file, do:
"""
import torch
from decoder import Decoder
from .decoder import Decoder
def test_decoder():

View File

@ -30,18 +30,12 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -51,6 +45,13 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,11 +20,12 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import make_pad_mask
from .encoder_interface import EncoderInterface
from .subsampling import Conv2dSubsampling, VggSubsampling
class Transformer(EncoderInterface):
def __init__(

View File

@ -29,10 +29,7 @@ from lhotse.dataset import (
K2SpeechRecognitionDataset,
SpecAugment,
)
from lhotse.dataset.input_strategies import (
OnTheFlyFeatures,
PrecomputedFeatures,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
from torch.utils.data import DataLoader
from icefall.utils import str2bool

View File

@ -63,16 +63,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from aishell import AIShell
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
@ -83,6 +73,17 @@ from icefall.utils import (
write_error_stats,
)
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -48,16 +48,17 @@ from pathlib import Path
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,10 +20,11 @@ from typing import Optional
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
from .encoder_interface import EncoderInterface
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf

View File

@ -65,17 +65,18 @@ import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
from .train import get_params, get_transducer_model
def get_parser():

View File

@ -50,21 +50,13 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from aidatatang_200zh import AIDatatang200zh
from aishell import AIShell
from asr_datamodule import AsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -74,6 +66,15 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .aidatatang_200zh import AIDatatang200zh
from .aishell import AIShell
from .asr_datamodule import AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -65,15 +65,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
@ -84,6 +75,16 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AishellAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -48,16 +48,17 @@ from pathlib import Path
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -65,17 +65,18 @@ import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
from .train import get_params, get_transducer_model
def get_parser():

View File

@ -23,7 +23,8 @@ To run this file, do:
"""
import torch
from decoder import Decoder
from .decoder import Decoder
def test_decoder():

View File

@ -46,18 +46,12 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -67,6 +61,13 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import AishellAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -110,18 +110,6 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AiShell2AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
@ -139,6 +127,19 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AiShell2AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -49,7 +49,6 @@ import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -60,6 +59,8 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -59,17 +59,18 @@ import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():

View File

@ -61,19 +61,12 @@ from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AiShell2AsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -92,8 +85,15 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import AiShell2AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -39,7 +39,8 @@ from typing import Dict, List
import k2
import torch
from prepare_lang import (
from .prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,

View File

@ -22,7 +22,8 @@ import os
import tempfile
import k2
from prepare_lang import (
from .prepare_lang import (
add_disambig_symbols,
generate_id_map,
get_phones,

View File

@ -61,17 +61,7 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import Aishell4AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from local.text_normalize import text_normalize
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -88,6 +78,17 @@ from icefall.utils import (
write_error_stats,
)
from ..local.text_normalize import text_normalize
from .asr_datamodule import Aishell4AsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -49,7 +49,6 @@ import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -60,6 +59,8 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -72,17 +72,18 @@ import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():

View File

@ -23,7 +23,7 @@ To run this file, do:
python ./pruned_transducer_stateless5/test_model.py
"""
from train import get_params, get_transducer_model
from .train import get_params, get_transducer_model
def test_model_1():

View File

@ -53,20 +53,12 @@ from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import Aishell4AsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from local.text_normalize import text_normalize
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -85,8 +77,16 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from ..local.text_normalize import text_normalize
from .asr_datamodule import Aishell4AsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -39,7 +39,8 @@ from typing import Dict, List
import k2
import torch
from prepare_lang import (
from .prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,

View File

@ -22,7 +22,8 @@ import os
import tempfile
import k2
from prepare_lang import (
from .prepare_lang import (
add_disambig_symbols,
generate_id_map,
get_phones,

View File

@ -59,22 +59,9 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AlimeetingAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -83,6 +70,16 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import AlimeetingAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -46,12 +46,13 @@ import logging
from pathlib import Path
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -63,17 +63,18 @@ import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
from .train import get_params, get_transducer_model
def get_parser():

View File

@ -53,19 +53,12 @@ from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AlimeetingAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -81,8 +74,15 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import AlimeetingAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

View File

@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
import torch
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
from .transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):

View File

@ -27,9 +27,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from conformer import Conformer
from gigaspeech_scoring import asr_text_post_processing
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -52,6 +49,10 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import GigaSpeechAsrDataModule
from .conformer import Conformer
from .gigaspeech_scoring import asr_text_post_processing
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -27,14 +27,11 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -50,6 +47,10 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import GigaSpeechAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -19,10 +19,11 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
from .label_smoothing import LabelSmoothingLoss
from .subsampling import Conv2dSubsampling, VggSubsampling
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -66,22 +66,8 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from gigaspeech_scoring import asr_text_post_processing
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import (
AttributeDict,
setup_logger,
@ -89,6 +75,17 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import GigaSpeechAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .gigaspeech_scoring import asr_text_post_processing
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -49,15 +49,12 @@ from pathlib import Path
import sentencepiece as spm
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from .train import get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -56,14 +56,8 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -77,6 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from .asr_datamodule import GigaSpeechAsrDataModule
from .conformer import Conformer
from .decoder import Decoder
from .joiner import Joiner
from .model import Transducer
from .optim import Eden, Eve
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]

View File

@ -34,8 +34,6 @@ from pathlib import Path
import k2
import numpy as np
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse import CutSet
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
@ -51,6 +49,9 @@ from icefall.utils import (
setup_logger,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
import torch
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
from .transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):

View File

@ -26,8 +26,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -54,6 +52,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -24,12 +24,13 @@ import logging
from pathlib import Path
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -27,7 +27,6 @@ import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
@ -38,6 +37,8 @@ from icefall.decode import (
)
from icefall.utils import AttributeDict, get_texts
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -18,7 +18,8 @@
from distutils.version import LooseVersion
import torch
from label_smoothing import LabelSmoothingLoss
from .label_smoothing import LabelSmoothingLoss
torch_ver = LooseVersion(torch.__version__)

View File

@ -17,7 +17,8 @@
import torch
from subsampling import Conv2dSubsampling, VggSubsampling
from .subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():

View File

@ -18,7 +18,8 @@
import torch
from torch.nn.utils.rnn import pad_sequence
from transformer import (
from .transformer import (
Transformer,
add_eos,
add_sos,

View File

@ -38,15 +38,12 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
@ -63,6 +60,10 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -19,10 +19,11 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
from .label_smoothing import LabelSmoothingLoss
from .subsampling import Conv2dSubsampling, VggSubsampling
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -21,7 +21,7 @@ import torch.nn as nn
from torch import Tensor
from torch.nn.init import xavier_normal_
from scaling import ScaledLinear
from .scaling import ScaledLinear
class MultiheadAttention(nn.Module):

View File

@ -22,17 +22,17 @@ import warnings
from typing import Optional, Tuple
import torch
from scaling import (
from torch import Tensor, nn
from .scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv1d,
ScaledLinear,
)
from torch import Tensor, nn
from subsampling import Conv2dSubsampling
from transformer import Supervisions, Transformer, encoder_padding_mask
from .subsampling import Conv2dSubsampling
from .transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):

View File

@ -28,17 +28,14 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.decode import (
get_lattice,
nbest_decoding,
@ -62,6 +59,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -47,7 +47,6 @@ import logging
from pathlib import Path
import torch
from decode import get_params
from icefall.checkpoint import (
average_checkpoints,
@ -55,10 +54,11 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from conformer import Conformer
from icefall.utils import str2bool
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from .conformer import Conformer
from .decode import get_params
def get_parser():

View File

@ -17,14 +17,15 @@
# limitations under the License.
import torch
from scaling import (
from torch import nn
from .scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv2d,
ScaledLinear,
)
from torch import nn
class Conv2dSubsampling(nn.Module):

View File

@ -54,23 +54,19 @@ from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall import diagnostics
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -89,8 +85,12 @@ from icefall.utils import (
str2bool,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from .optim import Eden, Eve, LRScheduler
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
torch.optim.lr_scheduler._LRScheduler, LRScheduler
]

View File

@ -21,19 +21,18 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling
from attention import MultiheadAttention
from torch.nn.utils.rnn import pad_sequence
from scaling import (
from .attention import MultiheadAttention
from .label_smoothing import LabelSmoothingLoss
from .scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledLinear,
ScaledEmbedding,
ScaledLinear,
)
from .subsampling import Conv2dSubsampling
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -22,7 +22,8 @@ from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
from .transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):

View File

@ -26,8 +26,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -50,6 +48,9 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -1,9 +1,9 @@
#!/usr/bin/env python3
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch
from .subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():
N = 3

View File

@ -1,17 +1,17 @@
#!/usr/bin/env python3
import torch
from transformer import (
from torch.nn.utils.rnn import pad_sequence
from .transformer import (
Transformer,
add_eos,
add_sos,
decoder_padding_mask,
encoder_padding_mask,
generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
)
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask():
supervisions = {

View File

@ -28,31 +28,23 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.ali import (
convert_alignments_to_tensor,
load_alignments,
lookup_alignments,
)
from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.mmi import LFMMILoss
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():

View File

@ -28,31 +28,23 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.ali import (
convert_alignments_to_tensor,
load_alignments,
lookup_alignments,
)
from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.mmi import LFMMILoss
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
from .asr_datamodule import LibriSpeechAsrDataModule
from .conformer import Conformer
from .transformer import Noam
def get_parser():

View File

@ -20,9 +20,10 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
from .subsampling import Conv2dSubsampling, VggSubsampling
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]

View File

@ -80,15 +80,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -104,6 +95,16 @@ from icefall.utils import (
write_error_stats,
)
from .asr_datamodule import LibriSpeechAsrDataModule
from .beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from .train import add_model_arguments, get_params, get_transducer_model
LOG_EPS = math.log(1e-10)

View File

@ -23,8 +23,11 @@ from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import (
from icefall.utils import make_pad_mask
from .encoder_interface import EncoderInterface
from .scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
@ -33,9 +36,6 @@ from scaling import (
ScaledLinear,
)
from icefall.utils import make_pad_mask
LOG_EPSILON = math.log(1e-10)

View File

@ -64,7 +64,6 @@ from pathlib import Path
import sentencepiece as spm
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -74,6 +73,8 @@ from icefall.checkpoint import (
)
from icefall.utils import str2bool
from .train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(

View File

@ -20,10 +20,11 @@ from typing import List, Optional, Tuple
import k2
import torch
from beam_search import Hypothesis, HypothesisList
from icefall.utils import AttributeDict
from .beam_search import Hypothesis, HypothesisList
class Stream(object):
def __init__(

Some files were not shown because too many files have changed in this diff Show More