Mirror of https://github.com/k2-fsa/icefall.git
(synced 2025-09-03 06:04:18 +00:00)

Randomly combining output from different transformer encoder layers.

This commit is contained in:
parent 12de88043a
commit aecb6dce71
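
The hunks that loaded below contain only the plumbing for this change: a new add_model_arguments() helper in the training script, plus calls to it from the recipe's other entry points. The encoder-side code that actually does the random combining is not among the loaded hunks. As a rough sketch of the technique named in the commit message (the class name, tensor shapes, and weighting scheme here are assumptions, not the repository's code): during training, replace the final encoder output with a randomly weighted combination of the outputs of several encoder layers, and fall back to the ordinary final-layer output at inference time.

    import torch
    import torch.nn as nn


    class RandomLayerCombiner(nn.Module):
        """Sketch only: randomly mix per-layer encoder outputs in training."""

        def forward(self, layer_outputs: list) -> torch.Tensor:
            # layer_outputs: tensors of shape (T, N, C), one per encoder layer.
            if not self.training:
                # At inference time behave like a plain encoder stack:
                # return the output of the final layer.
                return layer_outputs[-1]
            stacked = torch.stack(layer_outputs, dim=0)  # (L, T, N, C)
            # Draw fresh random convex weights over the layers on each call.
            weights = torch.rand(stacked.size(0), device=stacked.device)
            weights = weights / weights.sum()
            return (weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)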
From a decoding script (inferred from its beam_search imports):

@@ -74,7 +74,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -197,6 +197,8 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )
 
+    add_model_arguments(parser)
+
     return parser
 
 
From a checkpoint-export script (inferred from load_checkpoint/average_checkpoints and str2bool):

@@ -49,7 +49,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.utils import str2bool
@@ -109,6 +109,8 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    add_model_arguments(parser)
+
     return parser
 
 
From a pretrained-model inference script (inferred from pad_sequence and the beam_search imports):

@@ -57,7 +57,7 @@ from beam_search import (
     modified_beam_search,
 )
 from torch.nn.utils.rnn import pad_sequence
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 
 def get_parser():
@@ -133,6 +133,8 @@ def get_parser():
         """,
     )
 
+    add_model_arguments(parser)
+
     return parser
 
 
From the training script itself, where the new helper is defined (the scripts above import it from train):

@@ -52,7 +52,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
 
@@ -73,6 +72,44 @@ from icefall.utils import (
 )
 
 
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=12,
+        help="Number of transformer encoder layers",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in a transformer encoder layer",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Feedforward dimension of linear layers after attention in "
+        "the transformer model",
+    )
+
+    parser.add_argument(
+        "--attention-dim",
+        type=int,
+        default=512,
+        help="Attention dimension in a transformer encoder layer",
+    )
+
+    parser.add_argument(
+        "--embedding-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension for the decoder network",
+    )
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
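
Each entry-point script registers these flags on its own parser: the earlier hunks and the next one all add the same add_model_arguments(parser) call just before return parser. A minimal, self-contained illustration of the pattern (the two-flag helper below is a trimmed stand-in for the full add_model_arguments above):

    import argparse


    def add_model_arguments(parser: argparse.ArgumentParser):
        # Trimmed stand-in for the helper added in this commit; see the
        # diff above for the full set of flags.
        parser.add_argument("--num-encoder-layers", type=int, default=12)
        parser.add_argument("--nhead", type=int, default=8)


    parser = argparse.ArgumentParser()
    add_model_arguments(parser)
    args = parser.parse_args(["--num-encoder-layers", "6"])
    assert args.num_encoder_layers == 6  # overridden on the command line
    assert args.nhead == 8               # falls back to the default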
@@ -229,6 +266,8 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )
 
+    add_model_arguments(parser)
+
     return parser
 
 
@@ -270,10 +309,6 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor: The subsampling factor for the model.
 
-        - attention_dim: Hidden dim for multi-head attention model.
-
-        - num_decoder_layers: Number of decoder layer of transformer decoder.
-
         - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
@@ -290,13 +325,7 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
             "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
             # parameters for Noam
             "warm_step": 80000,  # For the 100h subset, use 30000
             "env_info": get_env_info(),
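
With these defaults deleted from get_params(), the model hyperparameters presumably flow in from the command line instead. The exact wiring is not among the loaded hunks; a common icefall pattern is to merge the parsed arguments into params, roughly as follows (assuming the script-level names from this diff, and that AttributeDict behaves like a dict):

    from train import get_params, get_parser, get_transducer_model

    args = get_parser().parse_args()
    params = get_params()      # no longer carries the model hyperparameters
    params.update(vars(args))  # hypothetical wiring: CLI values fill them in
    model = get_transducer_model(params)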