mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix help
This commit is contained in:
parent
73a281725c
commit
dc25ab909a
@ -30,11 +30,11 @@ import math
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from rnn_lm.dataset import get_dataloader
|
from dataset import get_dataloader
|
||||||
from rnn_lm.model import RnnLmModel
|
from model import RnnLmModel
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.utils import AttributeDict, load_averaged_model, setup_logger
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -95,10 +95,19 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-layers",
|
"--num-layers",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=3,
|
||||||
help="Number of RNN layers the model",
|
help="Number of RNN layers the model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tie-weights",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""True to share the weights between the input embedding layer and the
|
||||||
|
last output linear layer
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--batch-size",
|
"--batch-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -161,16 +170,22 @@ def main():
|
|||||||
embedding_dim=params.embedding_dim,
|
embedding_dim=params.embedding_dim,
|
||||||
hidden_dim=params.hidden_dim,
|
hidden_dim=params.hidden_dim,
|
||||||
num_layers=params.num_layers,
|
num_layers=params.num_layers,
|
||||||
|
tie_weights=params.tie_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.avg == 1:
|
if params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
model.to(device)
|
||||||
else:
|
else:
|
||||||
model = load_averaged_model(
|
start = params.epoch - params.avg + 1
|
||||||
params.exp_dir, model, params.epoch, params.avg, device
|
filenames = []
|
||||||
)
|
for i in range(start, params.epoch + 1):
|
||||||
|
if start >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
num_param_requires_grad = sum(
|
num_param_requires_grad = sum(
|
||||||
|
@ -83,7 +83,7 @@ def get_parser():
|
|||||||
"--tie-weights",
|
"--tie-weights",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="""True share the weights between the input embedding layer and the
|
help="""True to share the weights between the input embedding layer and the
|
||||||
last output linear layer
|
last output linear layer
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
@ -166,7 +166,7 @@ def get_parser():
|
|||||||
"--tie-weights",
|
"--tie-weights",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="""True share the weights between the input embedding layer and the
|
help="""True to share the weights between the input embedding layer and the
|
||||||
last output linear layer
|
last output linear layer
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user