[TF saved_model_cli] Allow user to set target_cpu for xla aot compilation.

PiperOrigin-RevId: 312290453
Change-Id: I024e2b3884e436578e351d43961199e4c28307c3
Eugene Brevdo authored 2020-05-19 09:04:05 -07:00; committed by TensorFlower Gardener
parent fb7ba76670
commit e25d9862ca
5 changed files with 24 additions and 4 deletions

View File

@@ -20,7 +20,7 @@ load(
     "tf_cc_test",
     "tf_copts",
 )
-load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
+load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
 
 def tf_library(
         name,
@@ -188,7 +188,9 @@ def tf_library(
     # `find` on such an object.
     need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
-    flags = tfcompile_extra_flags() + flags
+    target_cpu = tfcompile_target_cpu()
+    extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
+    flags = extra_flags + flags
 
     if enable_xla_hlo_profiling:
         profiling_flag = "--xla_hlo_profile"
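
To make the new flag plumbing concrete, here is a minimal sketch (not part of the commit; the helper name `_build_tfcompile_flags` is invented) of how the extra_flags string added to tf_library behaves:

    # Hypothetical helper mirroring the logic added to tf_library above.
    def _build_tfcompile_flags(flags, target_cpu):
        # Prepend --target_cpu=<cpu> when a cpu is configured; otherwise keep
        # the user-supplied flags (with only a leading space added).
        extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
        return extra_flags + flags

    print(_build_tfcompile_flags("--gen_program_shape", "skylake"))
    # --target_cpu=skylake --gen_program_shape
    print(_build_tfcompile_flags("--gen_program_shape", ""))
    #  --gen_program_shape   (OSS default, since tfcompile_target_cpu() returns "")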

View File

@@ -215,6 +215,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                    signature_def_key,
                                    cpp_class,
                                    target_triple,
+                                   target_cpu,
                                    variables_to_feed=(),
                                    enable_multithreading=False):
   """Compile a `MetaGraphDef` to header+object files in `output_prefix`.
@@ -239,6 +240,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
     signature_def_key: String, the signature_def to use in the SavedModel.
     cpp_class: String, Name of output C++ class.
     target_triple: String, LLVM target triple.
+    target_cpu: String, LLVM target cpu name.
     variables_to_feed: A list of strings, the variables that will be fed by the
       user; these won't be frozen. If `None`, then we will extract all the
       variables in the graph and mark them as to-feed. The default behavior is
@@ -367,6 +369,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
         config=config_pbtxt_location,
         cpp_class=cpp_class,
         target_triple=target_triple,
+        target_cpu=target_cpu,
         entry_point='entry_{}'.format(entry_digest),
         out_function_object='{}.o'.format(output_prefix),
         out_header='{}.h'.format(output_prefix),

View File

@@ -825,6 +825,7 @@ def aot_compile_cpu(args):
       variables_to_feed=variables_to_feed,
       output_prefix=args.output_prefix,
       target_triple=args.target_triple,
+      target_cpu=args.target_cpu,
       cpp_class=args.cpp_class,
       enable_multithreading=args.enable_multithreading)
@@ -1096,6 +1097,14 @@ def add_aot_compile_cpu_subparser(subparsers):
           'x86_64-none-darwin, x86_64-apple-ios, arm64-none-ios, '
           'armv7-none-android. More examples are available in tfcompile.bzl '
           'in the tensorflow codebase.'))
+  parser_compile.add_argument(
+      '--target_cpu',
+      type=str,
+      default='',
+      help=('Target cpu name for LLVM during AOT compilation. Examples: '
+            'x86_64, skylake, haswell, westmere, <empty> (unknown). For '
+            'a complete list of options, run (for x86 targets): '
+            '`llc -march=x86 -mcpu=help`'))
   parser_compile.add_argument(
       '--checkpoint_path',
       type=str,
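
As a rough usage sketch, the new flag would be passed to the CLI as shown below. This is Python driving the tool via subprocess; the paths, class name, and cpu value are invented placeholders, and the other flags are the standard aot_compile_cpu arguments rather than part of this change:

    import subprocess

    subprocess.run([
        'saved_model_cli', 'aot_compile_cpu',
        '--dir', '/tmp/my_saved_model',            # placeholder SavedModel path
        '--tag_set', 'serve',
        '--signature_def_key', 'serving_default',
        '--output_prefix', '/tmp/aot/my_model',    # placeholder output prefix
        '--cpp_class', 'MyModel',                  # placeholder C++ class name
        '--target_triple', 'x86_64-pc-linux',
        '--target_cpu', 'skylake',                 # new flag from this change
    ], check=True)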

View File

@@ -1,6 +1,7 @@
 """Definitions for using tools like saved_model_cli."""
 
 load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available")
+load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
 load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple")
 
 def _maybe_force_compile(args, force_compile):
@@ -19,6 +20,7 @@ def saved_model_compile_aot(
         signature_def = "serving_default",
         variables_to_feed = "",
         target_triple = None,
+        target_cpu = None,
         force_without_xla_support_flag = True,
         tags = None):
     """Compile a SavedModel directory accessible from a filegroup.
@@ -88,7 +90,9 @@ def saved_model_compile_aot(
         uninitialized in the compiled object (this applies to all input
         arguments from the signature as well).
       target_triple: The LLVM target triple to use (defaults to current build's
-        target architecture's triple).
+        target architecture's triple). Similar to clang's -target flag.
+      target_cpu: The LLVM cpu name used for compilation. Similar to clang's
+        -mcpu flag.
       force_without_xla_support_flag: Whether to compile even when
         `--define=with_xla_support=true` is not set. If `False`, and the
         define is not passed when building, then the created `cc_library`
@@ -100,6 +104,7 @@ def saved_model_compile_aot(
     """
     saved_model = "{}/saved_model.pb".format(directory)
     target_triple = target_triple or target_llvm_triple()
+    target_cpu = target_cpu or tfcompile_target_cpu() or ""
     variables_to_feed = variables_to_feed or "''"
     if checkpoint_path:
         checkpoint_cmd_args = (
@@ -131,6 +136,7 @@ def saved_model_compile_aot(
             "--variables_to_feed {} ".format(variables_to_feed) +
             "--signature_def_key {} ".format(signature_def) +
             "--target_triple " + target_triple + " " +
+            ("--target_cpu " + target_cpu + " " if target_cpu else "") +
             "--tag_set {} ".format(tag_set)
         ),
         tags = tags,
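
For the Bazel side, a BUILD-file sketch of opting into the new attribute might look like the following (the load path, target name, directory, and cpp_class are illustrative assumptions; only target_cpu is introduced by this change, and omitting it preserves the previous behavior):

    load("//tensorflow/python/tools:tools.bzl", "saved_model_compile_aot")  # assumed load path

    saved_model_compile_aot(
        name = "compiled_my_model",        # illustrative target name
        directory = "my_saved_model_dir",  # illustrative SavedModel directory
        cpp_class = "MyModel",             # illustrative C++ class name
        target_cpu = "haswell",            # new attribute; defaults to tfcompile_target_cpu()
    )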

View File

@@ -2866,7 +2866,7 @@ def if_mlir(if_true, if_false = []):
         "//conditions:default": if_false,
     })
 
-def tfcompile_extra_flags():
+def tfcompile_target_cpu():
     return ""
 
 def tf_external_workspace_visible(visibility):