from local
This commit is contained in:
parent
db5f76ce67
commit
2a63e6bcc3
Binary file not shown.
@ -534,22 +534,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
def to_int_tuple(s: str):
|
def to_int_tuple(s: str):
|
||||||
return tuple(map(int, s.split(",")))
|
return tuple(map(int, s.split(",")))
|
||||||
|
|
||||||
'''
|
|
||||||
encoder = Zipformer(
|
|
||||||
num_features=params.feature_dim,
|
|
||||||
output_downsampling_factor=2,
|
|
||||||
zipformer_downsampling_factors=to_int_tuple(
|
|
||||||
params.zipformer_downsampling_factors
|
|
||||||
),
|
|
||||||
encoder_dims=to_int_tuple(params.encoder_dims),
|
|
||||||
attention_dim=to_int_tuple(params.attention_dims),
|
|
||||||
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
|
|
||||||
nhead=to_int_tuple(params.nhead),
|
|
||||||
feedforward_dim=to_int_tuple(params.feedforward_dims),
|
|
||||||
cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
|
|
||||||
num_encoder_layers=to_int_tuple(params.num_encoder_layers),
|
|
||||||
)
|
|
||||||
'''
|
|
||||||
if params.encoder_type == 'd2v':
|
if params.encoder_type == 'd2v':
|
||||||
encoder = FairSeqData2VecEncoder(
|
encoder = FairSeqData2VecEncoder(
|
||||||
input_size=params.encoder_dim,
|
input_size=params.encoder_dim,
|
||||||
@ -558,6 +542,21 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
freeze_finetune_updates=params.freeze_finetune_updates*params.accum_grads,
|
freeze_finetune_updates=params.freeze_finetune_updates*params.accum_grads,
|
||||||
additional_block=params.additional_block,
|
additional_block=params.additional_block,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
encoder = Zipformer(
|
||||||
|
num_features=params.feature_dim,
|
||||||
|
output_downsampling_factor=2,
|
||||||
|
zipformer_downsampling_factors=to_int_tuple(
|
||||||
|
params.zipformer_downsampling_factors
|
||||||
|
),
|
||||||
|
encoder_dims=to_int_tuple(params.encoder_dims),
|
||||||
|
attention_dim=to_int_tuple(params.attention_dims),
|
||||||
|
encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
|
||||||
|
nhead=to_int_tuple(params.nhead),
|
||||||
|
feedforward_dim=to_int_tuple(params.feedforward_dims),
|
||||||
|
cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
|
||||||
|
num_encoder_layers=to_int_tuple(params.num_encoder_layers),
|
||||||
|
)
|
||||||
|
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user