from local

This commit is contained in:
dohe0342 2023-04-11 16:25:21 +09:00
parent 1147d1afa9
commit dcfd8fb6a8
2 changed files with 1 additions and 4 deletions

View File

@ -1560,7 +1560,7 @@ def run_adapter(rank, world_size, args, wb=None):
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params, prompt=True)
num_param = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()]) num_param = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -1618,8 +1618,6 @@ def run_adapter(rank, world_size, args, wb=None):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
prompt = prompt.to('cuda')
''' '''
if params.hpo: if params.hpo:
train_cuts = librispeech.train_clean_10_cuts(option=params.gender) train_cuts = librispeech.train_clean_10_cuts(option=params.gender)
@ -1689,7 +1687,6 @@ def run_adapter(rank, world_size, args, wb=None):
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,
wb=wb, wb=wb,
prompt=prompt,
) )
if params.print_diagnostics: if params.print_diagnostics: