This commit is contained in:
Erwan 2022-06-23 13:18:23 +02:00
parent 73a281725c
commit dc25ab909a
3 changed files with 26 additions and 11 deletions

View File

@@ -30,11 +30,11 @@ import math
from pathlib import Path
import torch
from rnn_lm.dataset import get_dataloader
from rnn_lm.model import RnnLmModel
from dataset import get_dataloader
from model import RnnLmModel
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, load_averaged_model, setup_logger
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
@@ -95,10 +95,19 @@ def get_parser():
parser.add_argument(
"--num-layers",
type=int,
default=4,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--batch-size",
type=int,
@@ -161,16 +170,22 @@ def main():
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
else:
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
)
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
num_param_requires_grad = sum(

View File

@@ -83,7 +83,7 @@ def get_parser():
"--tie-weights",
type=str2bool,
default=True,
help="""True share the weights between the input embedding layer and the
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)

View File

@@ -166,7 +166,7 @@ def get_parser():
"--tie-weights",
type=str2bool,
default=False,
help="""True share the weights between the input embedding layer and the
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)