do some changes for export.py (#437)

This commit is contained in:
Mingshuang Luo 2022-06-20 14:57:08 +08:00 committed by GitHub
parent a42d96dfe0
commit 998091ef52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 56 additions and 23 deletions

View File

@ -114,8 +114,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -155,6 +153,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -184,8 +184,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -225,6 +223,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -182,8 +182,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
def main(): def main():
args = get_parser().parse_args() args = get_parser().parse_args()
assert args.jit is False, "torchscript support will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -223,6 +221,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -182,8 +182,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
def main(): def main():
args = get_parser().parse_args() args = get_parser().parse_args()
assert args.jit is False, "torchscript support will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -223,6 +221,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -149,8 +149,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -252,6 +250,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -114,8 +114,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -155,6 +153,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -131,8 +131,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -191,6 +189,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -130,8 +130,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -178,6 +176,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -117,8 +117,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -161,6 +159,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# #
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
# Mingshuang Luo) # Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -185,8 +185,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -229,6 +227,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -114,8 +114,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -155,6 +153,11 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"