from local

This commit is contained in:
dohe0342 2023-01-09 22:48:06 +09:00
parent cb5055d5f6
commit 44d0f6e0fb
2 changed files with 5 additions and 3 deletions

View File

@ -975,8 +975,11 @@ def run(rank, world_size, args):
logging.info("About to create model")
transducer_model = get_transducer_model(params)
try: pre_trained_model = torch.load('/workspace/icefall/egs/librispeech/ASR/incremental_transf/conformer_12layers.pt')
except: pre_trained_model = torch.load('/home/work/workspace/icefall/egs/librispeech/ASR/incremental_transf/conformer_12layers.pt')
try:
path = '/workspace/icefall/egs/librispeech/ASR/incremental_transf/conformer_12layers.pt'
load_checkpoint(transducer_model)
except:
path = '/home/work/workspace/icefall/egs/librispeech/ASR/incremental_transf/conformer_12layers.pt'
pre_trained_model = pre_trained_model['model']
transducer_model.load_state_dict(pre_trained_model, strict=True)
transducer_model.to(device)
@ -987,7 +990,6 @@ def run(rank, world_size, args):
p.requires_grad = False
else:
print(n)
exit()
'''
for n, p in model.named_parameters():
if 'layer' not in n: