mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix help
This commit is contained in:
parent
73a281725c
commit
dc25ab909a
@@ -30,11 +30,11 @@ import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from rnn_lm.dataset import get_dataloader
|
||||
from rnn_lm.model import RnnLmModel
|
||||
from dataset import get_dataloader
|
||||
from model import RnnLmModel
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.utils import AttributeDict, load_averaged_model, setup_logger
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@@ -95,10 +95,19 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--num-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
default=3,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
@@ -161,16 +170,22 @@ def main():
|
||||
embedding_dim=params.embedding_dim,
|
||||
hidden_dim=params.hidden_dim,
|
||||
num_layers=params.num_layers,
|
||||
tie_weights=params.tie_weights,
|
||||
)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
model.to(device)
|
||||
else:
|
||||
model = load_averaged_model(
|
||||
params.exp_dir, model, params.epoch, params.avg, device
|
||||
)
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
num_param_requires_grad = sum(
|
||||
|
@@ -83,7 +83,7 @@ def get_parser():
|
||||
"--tie-weights",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""True share the weights between the input embedding layer and the
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
|
@@ -166,7 +166,7 @@ def get_parser():
|
||||
"--tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True share the weights between the input embedding layer and the
|
||||
help="""True to share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user