diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.prompt_tuning.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.prompt_tuning.py.swp index 9b9b096ff..946cd7f67 100644 Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.prompt_tuning.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.prompt_tuning.py.swp differ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/prompt_tuning.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/prompt_tuning.py index 472a284b4..39820a8aa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/prompt_tuning.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/prompt_tuning.py @@ -666,6 +666,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, + prompt=params.prompt, ) return model @@ -1558,7 +1559,7 @@ def run_adapter(rank, world_size, args, wb=None): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params, prompt=True) + model = get_transducer_model(params) num_param = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}")