minor fixes (#1345)

This commit is contained in:
zr_jin 2023-10-27 13:35:43 +08:00 committed by GitHub
parent 800bf4b6a2
commit ea78b32857
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 4 deletions

View File

@ -116,7 +116,7 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
@ -695,7 +695,7 @@ def main():
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:

View File

@ -586,7 +586,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module:
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
@ -1083,7 +1083,7 @@ def run(rank, world_size, args):
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")