from local

This commit is contained in:
dohe0342 2023-02-02 17:34:52 +09:00
parent 39ba7022f9
commit 18028ee995
3 changed files with 10 additions and 1 deletions

View File

@ -505,7 +505,8 @@ def main():
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
'''
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
@ -517,6 +518,14 @@ def main():
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
'''
model = Transformer(
num_features=params.feature_dim,
num_classes=num_classes,
use_feat_batchnorm=params.use_feat_batchnorm,
num_encoder_layers=params.num_encoder_layers,
num_decoder_layers=params.num_decoder_layers,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)