mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
do some changes for export.py (#437)
This commit is contained in:
parent
a42d96dfe0
commit
998091ef52
@ -114,8 +114,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -155,6 +153,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -184,8 +184,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -225,6 +223,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -182,8 +182,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
|
|
||||||
assert args.jit is False, "torchscript support will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -223,6 +221,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -182,8 +182,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
|
|
||||||
assert args.jit is False, "torchscript support will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -223,6 +221,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -149,8 +149,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -252,6 +250,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -114,8 +114,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -155,6 +153,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -131,8 +131,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -191,6 +189,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -130,8 +130,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -178,6 +176,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -117,8 +117,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -161,6 +159,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
|
||||||
# Mingshuang Luo)
|
# Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -185,8 +185,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -229,6 +227,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -114,8 +114,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -155,6 +153,11 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user