diff --git a/egs/librispeech/ASR/incremental_transf/train.py b/egs/librispeech/ASR/incremental_transf/train.py
index 10b3ad142..c852b9951 100755
--- a/egs/librispeech/ASR/incremental_transf/train.py
+++ b/egs/librispeech/ASR/incremental_transf/train.py
@@ -976,7 +976,24 @@ def run(rank, world_size, args):
 
     for n, p in model.named_parameters():
         if 'layer' not in n:
-            p.data = pre_trained_model2[n]
+            try: p.data = pre_trained_model2[n]
+            except: print(f'pre-trained model has no parameter named {n}.')
+        else:
+            layer_name_splited = n.split('.')
+            if int(layer_name_splited[3]) % 2 == 0:
+                layer_name_splited[0] = 'pt_encoder'
+                layer_name_splited[3] = str(int(layer_name_splited[3])//2)
+                old_name = '.'.join(layer_name_splited)
+                try: p.data = pre_trained_model[old_name]
+                except: print(f'pre-trained model has no parameter named {n}.')
+            else:
+                layer_name_splited[0] = 'inter_encoder'
+                layer_name_splited[3] = str(int(layer_name_splited[3])//2)
+                old_name = '.'.join(layer_name_splited)
+                try: p.data = pre_trained_model[old_name]
+                except: print(f'pre-trained model has no parameter named {n}.')
+
 
     num_param = sum([p.numel() for p in model.parameters()])
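
Note (not part of the patch): the hunk above initializes an interleaved encoder from two pre-trained checkpoints, copying even-numbered layers from `pt_encoder` and odd-numbered layers from `inter_encoder`. Below is a minimal standalone sketch of the name remapping it performs, assuming icefall-style parameter names of the form 'encoder.encoder.layers.<i>.<rest>' (so the dot-separated component at index 3 is the layer index); `remap_layer_name` is a hypothetical helper introduced here only for illustration.

def remap_layer_name(name: str) -> str:
    # Split e.g. 'encoder.encoder.layers.4.self_attn.in_proj_weight';
    # parts[3] is then the layer index within the interleaved encoder.
    parts = name.split('.')
    layer_idx = int(parts[3])
    # Even layers come from the pre-trained encoder, odd layers from the
    # intermediate encoder; each source model numbers its own layers
    # 0, 1, 2, ..., hence the floor division by two.
    parts[0] = 'pt_encoder' if layer_idx % 2 == 0 else 'inter_encoder'
    parts[3] = str(layer_idx // 2)
    return '.'.join(parts)

# Interleaved layer 4 is the third even layer, i.e. source layer 2 of
# pt_encoder; interleaved layer 5 is source layer 2 of inter_encoder.
assert remap_layer_name('encoder.encoder.layers.4.self_attn.in_proj_weight') \
    == 'pt_encoder.encoder.layers.2.self_attn.in_proj_weight'
assert remap_layer_name('encoder.encoder.layers.5.self_attn.in_proj_weight') \
    == 'inter_encoder.encoder.layers.2.self_attn.in_proj_weight'

The floor division works because interleaved layers 0, 2, 4, ... map onto consecutive source layers 0, 1, 2, ... of one checkpoint, and layers 1, 3, 5, ... onto consecutive layers of the other.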