Randomly combining output from different transformer encoder layers.

Fangjun Kuang 2022-03-25 17:39:57 +08:00
parent 12de88043a
commit aecb6dce71
4 changed files with 49 additions and 14 deletions
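The hunks below only add command-line plumbing for the model hyperparameters; the random layer combination named in the commit title happens inside the encoder and is not visible in these four files. As a rough, hypothetical sketch of that idea (every name below is illustrative, not taken from this diff): during training, decode from a random convex combination of several encoder layers' outputs rather than always from the final layer, and keep using the final layer at inference time.

import torch

def random_combine(layer_outputs, training: bool):
    # layer_outputs: one (T, N, C) tensor per encoder layer; the last
    # entry is the final layer's output.
    if not training:
        # Inference is unchanged: use the final layer only.
        return layer_outputs[-1]
    # Training: draw random non-negative weights and normalize them,
    # yielding a random convex combination of the layer outputs.
    weights = torch.rand(len(layer_outputs), device=layer_outputs[0].device)
    weights = weights / weights.sum()
    return sum(w * out for w, out in zip(weights, layer_outputs))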

File 1 of 4

@@ -74,7 +74,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,

@@ -197,6 +197,8 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )
 
+    add_model_arguments(parser)
+
     return parser

File 2 of 4

@@ -49,7 +49,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.utils import str2bool

@@ -109,6 +109,8 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    add_model_arguments(parser)
+
     return parser

File 3 of 4

@@ -57,7 +57,7 @@ from beam_search import (
     modified_beam_search,
 )
 from torch.nn.utils.rnn import pad_sequence
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 
 def get_parser():

@@ -133,6 +133,8 @@ def get_parser():
         """,
     )
 
+    add_model_arguments(parser)
+
     return parser

File 4 of 4

@@ -52,7 +52,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -73,6 +72,44 @@ from icefall.utils import (
 )
 
 
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=12,
+        help="Number of transformer encoder layers",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in a transformer encoder layer",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Feedforward dimension of linear layers after attention in "
+        "the transformer model",
+    )
+
+    parser.add_argument(
+        "--attention-dim",
+        type=int,
+        default=512,
+        help="Attention dimension in a transformer encoder layer",
+    )
+
+    parser.add_argument(
+        "--embedding-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension for the decoder network",
+    )
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -229,6 +266,8 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )
 
+    add_model_arguments(parser)
+
     return parser
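With add_model_arguments() now registered by every script's get_parser(), hyperparameters that used to be hard-coded in get_params() can be overridden from the command line. A minimal sketch of the wiring, assuming train.py is importable and all remaining arguments have defaults:

from train import get_parser

parser = get_parser()
args = parser.parse_args(["--num-encoder-layers", "6", "--nhead", "4"])
assert args.num_encoder_layers == 6  # overrides the former hard-coded 12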
@@ -270,10 +309,6 @@ def get_params() -> AttributeDict:
         - subsampling_factor: The subsampling factor for the model.
 
-        - attention_dim: Hidden dim for multi-head attention model.
-
-        - num_decoder_layers: Number of decoder layer of transformer decoder.
-
         - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
@@ -290,13 +325,7 @@ def get_params() -> AttributeDict:
         # parameters for conformer
         "feature_dim": 80,
         "subsampling_factor": 4,
-        "attention_dim": 512,
-        "nhead": 8,
-        "dim_feedforward": 2048,
-        "num_encoder_layers": 12,
         "vgg_frontend": False,
-        # parameters for decoder
-        "embedding_dim": 512,
         # parameters for Noam
         "warm_step": 80000,  # For the 100h subset, use 30000
         "env_info": get_env_info(),