[TF saved_model_cli] Allow user to set target_cpu for XLA AOT compilation.
PiperOrigin-RevId: 312290453 Change-Id: I024e2b3884e436578e351d43961199e4c28307c3
This commit is contained in:
parent
fb7ba76670
commit
e25d9862ca
@ -20,7 +20,7 @@ load(
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
|
||||
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
|
||||
|
||||
def tf_library(
|
||||
name,
|
||||
@ -188,7 +188,9 @@ def tf_library(
|
||||
# `find` on such an object.
|
||||
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
|
||||
|
||||
flags = tfcompile_extra_flags() + flags
|
||||
target_cpu = tfcompile_target_cpu()
|
||||
extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
|
||||
flags = extra_flags + flags
|
||||
|
||||
if enable_xla_hlo_profiling:
|
||||
profiling_flag = "--xla_hlo_profile"
|
||||
|
@ -215,6 +215,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
|
||||
signature_def_key,
|
||||
cpp_class,
|
||||
target_triple,
|
||||
target_cpu,
|
||||
variables_to_feed=(),
|
||||
enable_multithreading=False):
|
||||
"""Compile a `MetaGraphDef` to header+object files in `output_prefix`.
|
||||
@ -239,6 +240,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
|
||||
signature_def_key: String, the signature_def to use in the SavedModel.
|
||||
cpp_class: String, Name of output C++ class.
|
||||
target_triple: String, LLVM target triple.
|
||||
target_cpu: String, LLVM target cpu name.
|
||||
variables_to_feed: A list of strings, the variables that will be fed by the
|
||||
user; these won't be frozen. If `None`, then we will extract all the
|
||||
variables in the graph and mark them as to-feed. The default behavior is
|
||||
@ -367,6 +369,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
|
||||
config=config_pbtxt_location,
|
||||
cpp_class=cpp_class,
|
||||
target_triple=target_triple,
|
||||
target_cpu=target_cpu,
|
||||
entry_point='entry_{}'.format(entry_digest),
|
||||
out_function_object='{}.o'.format(output_prefix),
|
||||
out_header='{}.h'.format(output_prefix),
|
||||
|
@ -825,6 +825,7 @@ def aot_compile_cpu(args):
|
||||
variables_to_feed=variables_to_feed,
|
||||
output_prefix=args.output_prefix,
|
||||
target_triple=args.target_triple,
|
||||
target_cpu=args.target_cpu,
|
||||
cpp_class=args.cpp_class,
|
||||
enable_multithreading=args.enable_multithreading)
|
||||
|
||||
@ -1096,6 +1097,14 @@ def add_aot_compile_cpu_subparser(subparsers):
|
||||
'x86_64-none-darwin, x86_64-apple-ios, arm64-none-ios, '
|
||||
'armv7-none-android. More examples are available in tfcompile.bzl '
|
||||
'in the tensorflow codebase.'))
|
||||
parser_compile.add_argument(
|
||||
'--target_cpu',
|
||||
type=str,
|
||||
default='',
|
||||
help=('Target cpu name for LLVM during AOT compilation. Examples: '
|
||||
'x86_64, skylake, haswell, westmere, <empty> (unknown). For '
|
||||
'a complete list of options, run (for x86 targets): '
|
||||
'`llc -march=x86 -mcpu=help`'))
|
||||
parser_compile.add_argument(
|
||||
'--checkpoint_path',
|
||||
type=str,
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Definitions for using tools like saved_model_cli."""
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available")
|
||||
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple")
|
||||
|
||||
def _maybe_force_compile(args, force_compile):
|
||||
@ -19,6 +20,7 @@ def saved_model_compile_aot(
|
||||
signature_def = "serving_default",
|
||||
variables_to_feed = "",
|
||||
target_triple = None,
|
||||
target_cpu = None,
|
||||
force_without_xla_support_flag = True,
|
||||
tags = None):
|
||||
"""Compile a SavedModel directory accessible from a filegroup.
|
||||
@ -88,7 +90,9 @@ def saved_model_compile_aot(
|
||||
uninitialized in the compiled object (this applies to all input
|
||||
arguments from the signature as well).
|
||||
target_triple: The LLVM target triple to use (defaults to current build's
|
||||
target architecture's triple).
|
||||
target architecture's triple). Similar to clang's -target flag.
|
||||
target_cpu: The LLVM cpu name used for compilation. Similar to clang's
|
||||
-mcpu flag.
|
||||
force_without_xla_support_flag: Whether to compile even when
|
||||
`--define=with_xla_support=true` is not set. If `False`, and the
|
||||
define is not passed when building, then the created `cc_library`
|
||||
@ -100,6 +104,7 @@ def saved_model_compile_aot(
|
||||
"""
|
||||
saved_model = "{}/saved_model.pb".format(directory)
|
||||
target_triple = target_triple or target_llvm_triple()
|
||||
target_cpu = target_cpu or tfcompile_target_cpu() or ""
|
||||
variables_to_feed = variables_to_feed or "''"
|
||||
if checkpoint_path:
|
||||
checkpoint_cmd_args = (
|
||||
@ -131,6 +136,7 @@ def saved_model_compile_aot(
|
||||
"--variables_to_feed {} ".format(variables_to_feed) +
|
||||
"--signature_def_key {} ".format(signature_def) +
|
||||
"--target_triple " + target_triple + " " +
|
||||
("--target_cpu " + target_cpu + " " if target_cpu else "") +
|
||||
"--tag_set {} ".format(tag_set)
|
||||
),
|
||||
tags = tags,
|
||||
|
@ -2866,7 +2866,7 @@ def if_mlir(if_true, if_false = []):
|
||||
"//conditions:default": if_false,
|
||||
})
|
||||
|
||||
def tfcompile_extra_flags():
|
||||
def tfcompile_target_cpu():
|
||||
return ""
|
||||
|
||||
def tf_external_workspace_visible(visibility):
|
||||
|
Loading…
Reference in New Issue
Block a user