Support choosing Transformer as the encoder.

Fangjun Kuang 2022-01-22 22:24:14 +08:00
parent d6050eb02e
commit ce5670f39e


@@ -54,7 +54,7 @@ from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam
+from transformer import Noam, Transformer
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -138,6 +138,13 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--encoder-type",
+        type=str,
+        default="conformer",
+        help="Type of the encoder. Valid values are: conformer and transformer",
+    )
+
     return parser
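
For context, the new flag is a standard argparse option. The following standalone sketch is not part of this commit; it simply mirrors the option added above, assuming nothing beyond argparse itself, to show how a value passed as --encoder-type becomes available to the training code:

    import argparse

    # Mirrors the option added above; in train.py it is registered
    # inside get_parser() alongside the other training flags.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-type",
        type=str,
        default="conformer",
        help="Type of the encoder. Valid values are: conformer and transformer",
    )

    args = parser.parse_args(["--encoder-type", "transformer"])
    print(args.encoder_type)  # prints: transformer
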
@@ -214,8 +221,16 @@ def get_params() -> AttributeDict:
 def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
+    if params.encoder_type == "conformer":
+        Encoder = Conformer
+    elif params.encoder_type == "transformer":
+        Encoder = Transformer
+    else:
+        raise ValueError(
+            f"Unsupported encoder type: {params.encoder_type}"
+            "\nPlease use conformer or transformer"
+        )
+    encoder = Encoder(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
         subsampling_factor=params.subsampling_factor,
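
The if/elif chain above is one straightforward way to pick the encoder class. As a design note, the same selection can be written as a table-driven dispatch, which keeps the set of valid names in one place. The sketch below is illustrative only and not from this commit; the stub classes stand in for the real Conformer and Transformer encoders:

    # Illustrative alternative to the if/elif selection; the stubs are
    # placeholders for icefall's actual encoder classes.
    class Conformer:
        pass

    class Transformer:
        pass

    _ENCODERS = {"conformer": Conformer, "transformer": Transformer}

    def get_encoder_class(encoder_type: str):
        # Reject unknown names up front, matching the intent of the
        # commit's error message.
        if encoder_type not in _ENCODERS:
            raise ValueError(
                f"Unsupported encoder type: {encoder_type}\n"
                "Please use conformer or transformer"
            )
        return _ENCODERS[encoder_type]

    assert get_encoder_class("transformer") is Transformer
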
@@ -360,7 +375,8 @@ def compute_loss(
       params:
         Parameters for training. See :func:`get_params`.
       model:
-        The model for training. It is an instance of Conformer in our case.
+        The model for training. It is an instance of Conformer or Transformer
+        in our case.
       batch:
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.