Erwan 2022-06-23 13:18:23 +02:00
parent 73a281725c
commit dc25ab909a
3 changed files with 26 additions and 11 deletions

View File

@@ -30,11 +30,11 @@ import math
 from pathlib import Path
 import torch
-from rnn_lm.dataset import get_dataloader
-from rnn_lm.model import RnnLmModel
-from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict, load_averaged_model, setup_logger
+from dataset import get_dataloader
+from model import RnnLmModel
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.utils import AttributeDict, setup_logger, str2bool
 def get_parser():
@@ -95,10 +95,19 @@ def get_parser():
     parser.add_argument(
         "--num-layers",
         type=int,
-        default=4,
+        default=3,
         help="Number of RNN layers the model",
     )
+    parser.add_argument(
+        "--tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
     parser.add_argument(
         "--batch-size",
         type=int,
@@ -161,16 +170,22 @@ def main():
         embedding_dim=params.embedding_dim,
         hidden_dim=params.hidden_dim,
         num_layers=params.num_layers,
+        tie_weights=params.tie_weights,
     )
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        model.to(device)
     else:
-        model = load_averaged_model(
-            params.exp_dir, model, params.epoch, params.avg, device
-        )
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if start >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
-    model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
     num_param_requires_grad = sum(
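Note: the new else branch delegates the actual parameter averaging to icefall's average_checkpoints. As a rough, illustrative sketch of the idea only (not necessarily icefall's exact implementation; the function name average_state_dicts and the {"model": state_dict} checkpoint layout are assumptions here), averaging the last --avg epoch checkpoints amounts to:

import torch

def average_state_dicts(filenames, device="cpu"):
    # Illustrative sketch: sum the "model" state_dicts of all checkpoints,
    # then divide each entry by the number of checkpoints.
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location=device)["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        # Average floating-point weights; keep any integer buffers integral.
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n
    return avg

# Used the same way as the new code path:
#   model.load_state_dict(average_state_dicts(filenames, device=device))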

View File

@@ -83,7 +83,7 @@ def get_parser():
         "--tie-weights",
         type=str2bool,
         default=True,
-        help="""True share the weights between the input embedding layer and the
+        help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,
     )

View File

@@ -166,7 +166,7 @@ def get_parser():
         "--tie-weights",
         type=str2bool,
         default=False,
-        help="""True share the weights between the input embedding layer and the
+        help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,
     )
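Note: --tie-weights toggles weight tying between the input embedding and the final output projection (in the sense of Press & Wolf, "Using the Output Embedding to Improve Language Models"). A minimal sketch of how such tying is typically wired up, illustrative only: TinyRnnLm is a hypothetical module, not icefall's RnnLmModel, and the parameter names simply follow the options above.

import torch
import torch.nn as nn

class TinyRnnLm(nn.Module):
    # Illustrative RNN LM showing the usual effect of a tie_weights flag.
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, tie_weights=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)
        if tie_weights:
            # Both weight matrices must have shape (vocab_size, embedding_dim),
            # so tying requires hidden_dim == embedding_dim.
            assert embedding_dim == hidden_dim, "tie_weights needs embedding_dim == hidden_dim"
            self.output.weight = self.embedding.weight

    def forward(self, tokens, states=None):
        x = self.embedding(tokens)       # (N, T, embedding_dim)
        y, states = self.rnn(x, states)  # (N, T, hidden_dim)
        return self.output(y), states    # (N, T, vocab_size) logits

# Example usage on random token ids:
#   lm = TinyRnnLm(vocab_size=500, embedding_dim=256, hidden_dim=256, num_layers=3, tie_weights=True)
#   logits, _ = lm(torch.randint(0, 500, (2, 16)))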