from local

This commit is contained in:
dohe0342 2022-12-10 14:12:58 +09:00
parent db5f76ce67
commit 2a63e6bcc3
2 changed files with 15 additions and 16 deletions

View File

@ -534,22 +534,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
def to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
'''
encoder = Zipformer(
num_features=params.feature_dim,
output_downsampling_factor=2,
zipformer_downsampling_factors=to_int_tuple(
params.zipformer_downsampling_factors
),
encoder_dims=to_int_tuple(params.encoder_dims),
attention_dim=to_int_tuple(params.attention_dims),
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
nhead=to_int_tuple(params.nhead),
feedforward_dim=to_int_tuple(params.feedforward_dims),
cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
num_encoder_layers=to_int_tuple(params.num_encoder_layers),
)
'''
if params.encoder_type == 'd2v':
encoder = FairSeqData2VecEncoder(
input_size=params.encoder_dim,
@ -558,6 +542,21 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
freeze_finetune_updates=params.freeze_finetune_updates*params.accum_grads,
additional_block=params.additional_block,
)
else:
encoder = Zipformer(
num_features=params.feature_dim,
output_downsampling_factor=2,
zipformer_downsampling_factors=to_int_tuple(
params.zipformer_downsampling_factors
),
encoder_dims=to_int_tuple(params.encoder_dims),
attention_dim=to_int_tuple(params.attention_dims),
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
nhead=to_int_tuple(params.nhead),
feedforward_dim=to_int_tuple(params.feedforward_dims),
cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
num_encoder_layers=to_int_tuple(params.num_encoder_layers),
)
return encoder