mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-30 20:24:18 +00:00
Disable decoder layers in pretrained.py if it is not used.
This commit is contained in:
parent
2de12b195e
commit
b975bdef9e
@ -226,7 +226,13 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
|
if args.method != "attention-decoder":
|
||||||
|
# to save memory as the attention decoder
|
||||||
|
# will not be used
|
||||||
|
params.num_decoder_layers = 0
|
||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
@ -248,7 +254,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||||
model.load_state_dict(checkpoint["model"])
|
model.load_state_dict(checkpoint["model"], strict=False)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user