Merge da98aa1b8ea4a7e318b0c91702724cc853bb08ff into 8e6fd97c6b92826b2e13caa1a683f2eb4bf9a832

commit bff66ce86a
Author: Fangjun Kuang
Date: 2022-01-25 08:04:37 +01:00 (committed by GitHub)
2 changed files with 38 additions and 6 deletions

Changed file 1 of 2:

@@ -51,6 +51,7 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from model import Transducer
+from transformer import Transformer

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info

@@ -128,6 +129,13 @@ def get_parser():
         help="Maximum number of symbols per frame",
     )

+    parser.add_argument(
+        "--encoder-type",
+        type=str,
+        default="conformer",
+        help="Type of the encoder. Valid values are: conformer and transformer",
+    )
+
     return parser

@@ -150,8 +158,16 @@ def get_params() -> AttributeDict:
 def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
+    if params.encoder_type == "conformer":
+        Encoder = Conformer
+    elif params.encoder_type == "transformer":
+        Encoder = Transformer
+    else:
+        raise ValueError(
+            f"Unsupported encoder type: {params.encoder_type}"
+            "\nPlease use conformer or transformer"
+        )
+    encoder = Encoder(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
         subsampling_factor=params.subsampling_factor,
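
For reference, a minimal, self-contained sketch (not part of the commit) of the dispatch pattern both changed files now share: the new --encoder-type value picks which encoder class get_encoder_model() instantiates, and any other value raises a ValueError. The Conformer and Transformer classes are stubbed below; in the recipe they are the real encoders from conformer.py and transformer.py, constructed with the keyword arguments shown in the hunk above, and the constructor values here are illustrative only.

import argparse


class Conformer:
    """Stub standing in for the recipe's Conformer encoder."""

    def __init__(self, num_features, output_dim, subsampling_factor):
        self.kind = "conformer"


class Transformer:
    """Stub standing in for the recipe's Transformer encoder."""

    def __init__(self, num_features, output_dim, subsampling_factor):
        self.kind = "transformer"


def get_encoder_model(encoder_type: str):
    # Same selection logic as the hunk above, keyed on the flag value.
    if encoder_type == "conformer":
        Encoder = Conformer
    elif encoder_type == "transformer":
        Encoder = Transformer
    else:
        raise ValueError(
            f"Unsupported encoder type: {encoder_type}"
            "\nPlease use conformer or transformer"
        )
    # Illustrative sizes; the recipe takes these from `params`.
    return Encoder(num_features=80, output_dim=512, subsampling_factor=4)


parser = argparse.ArgumentParser()
parser.add_argument("--encoder-type", type=str, default="conformer")
args = parser.parse_args(["--encoder-type", "transformer"])
print(get_encoder_model(args.encoder_type).kind)  # prints: transformer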

Changed file 2 of 2:

@@ -54,7 +54,7 @@ from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam
+from transformer import Noam, Transformer

 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl

@@ -138,6 +138,13 @@ def get_parser():
         "2 means tri-gram",
     )

+    parser.add_argument(
+        "--encoder-type",
+        type=str,
+        default="conformer",
+        help="Type of the encoder. Valid values are: conformer and transformer",
+    )
+
     return parser
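
Because the flag defaults to "conformer", running the script without it keeps the previous behaviour. A small sketch of just the argparse side (the argument definition is copied from the hunk above; the bare parser here is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--encoder-type",
    type=str,
    default="conformer",
    help="Type of the encoder. Valid values are: conformer and transformer",
)

args = parser.parse_args([])  # no flag given: previous (Conformer) behaviour
assert args.encoder_type == "conformer"

args = parser.parse_args(["--encoder-type", "transformer"])  # opt in to Transformer
assert args.encoder_type == "transformer"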
@@ -214,8 +221,16 @@ def get_params() -> AttributeDict:
 def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
+    if params.encoder_type == "conformer":
+        Encoder = Conformer
+    elif params.encoder_type == "transformer":
+        Encoder = Transformer
+    else:
+        raise ValueError(
+            f"Unsupported encoder type: {params.encoder_type}"
+            "\nPlease use conformer or transformer"
+        )
+    encoder = Encoder(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
         subsampling_factor=params.subsampling_factor,

@@ -360,7 +375,8 @@ def compute_loss(
       params:
         Parameters for training. See :func:`get_params`.
       model:
-        The model for training. It is an instance of Conformer in our case.
+        The model for training. It is an instance of Conformer or Transformer
+        in our case.
       batch:
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.