from local

This commit is contained in:
dohe0342 2023-01-09 20:29:10 +09:00
parent b6888bd0de
commit b266758803
2 changed files with 17 additions and 0 deletions

View File

@ -444,6 +444,23 @@ def get_params() -> AttributeDict:
return params
def get_tempencoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Tempformer(
num_features=params.feature_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.encoder_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
)
return encoder
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(