From e25d9862ca5c42997112c564f1253fd001bc4a15 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Tue, 19 May 2020 09:04:05 -0700
Subject: [PATCH] [TF saved_model_cli] Allow user to set target_cpu for xla aot compilation.

PiperOrigin-RevId: 312290453
Change-Id: I024e2b3884e436578e351d43961199e4c28307c3
---
 tensorflow/compiler/aot/tfcompile.bzl              | 6 ++++--
 tensorflow/python/tools/saved_model_aot_compile.py | 3 +++
 tensorflow/python/tools/saved_model_cli.py         | 9 +++++++++
 tensorflow/python/tools/tools.bzl                  | 8 +++++++-
 tensorflow/tensorflow.bzl                          | 2 +-
 5 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 208b01c49d5..f2b28e70ff1 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -20,7 +20,7 @@ load(
     "tf_cc_test",
     "tf_copts",
 )
-load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
+load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
 
 def tf_library(
         name,
@@ -188,7 +188,9 @@ def tf_library(
     # `find` on such an object.
     need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
 
-    flags = tfcompile_extra_flags() + flags
+    target_cpu = tfcompile_target_cpu()
+    extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
+    flags = extra_flags + flags
 
     if enable_xla_hlo_profiling:
         profiling_flag = "--xla_hlo_profile"
diff --git a/tensorflow/python/tools/saved_model_aot_compile.py b/tensorflow/python/tools/saved_model_aot_compile.py
index a8694454ef2..5a34d10420a 100644
--- a/tensorflow/python/tools/saved_model_aot_compile.py
+++ b/tensorflow/python/tools/saved_model_aot_compile.py
@@ -215,6 +215,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                    signature_def_key,
                                    cpp_class,
                                    target_triple,
+                                   target_cpu,
                                    variables_to_feed=(),
                                    enable_multithreading=False):
   """Compile a `MetaGraphDef` to header+object files in `output_prefix`.
@@ -239,6 +240,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
     signature_def_key: String, the signature_def to use in the SavedModel.
     cpp_class: String, Name of output C++ class.
     target_triple: String, LLVM target triple.
+    target_cpu: String, LLVM target cpu name.
     variables_to_feed: A list of strings, the variables that will be fed by the
       user; these won't be frozen. If `None`, then we will extract all the
       variables in the graph and mark them as to-feed. The default behavior is
@@ -367,6 +369,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
       config=config_pbtxt_location,
       cpp_class=cpp_class,
       target_triple=target_triple,
+      target_cpu=target_cpu,
       entry_point='entry_{}'.format(entry_digest),
       out_function_object='{}.o'.format(output_prefix),
       out_header='{}.h'.format(output_prefix),
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 261ee1b9e9d..0f8f68436a3 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -825,6 +825,7 @@ def aot_compile_cpu(args):
       variables_to_feed=variables_to_feed,
       output_prefix=args.output_prefix,
       target_triple=args.target_triple,
+      target_cpu=args.target_cpu,
       cpp_class=args.cpp_class,
       enable_multithreading=args.enable_multithreading)
 
@@ -1096,6 +1097,14 @@ def add_aot_compile_cpu_subparser(subparsers):
             'x86_64-none-darwin, x86_64-apple-ios, arm64-none-ios, '
             'armv7-none-android. More examples are available in tfcompile.bzl '
             'in the tensorflow codebase.'))
+  parser_compile.add_argument(
+      '--target_cpu',
+      type=str,
+      default='',
+      help=('Target cpu name for LLVM during AOT compilation. Examples: '
+            'x86_64, skylake, haswell, westmere, <empty> (unknown). For '
+            'a complete list of options, run (for x86 targets): '
+            '`llc -march=x86 -mcpu=help`'))
   parser_compile.add_argument(
       '--checkpoint_path',
       type=str,
diff --git a/tensorflow/python/tools/tools.bzl b/tensorflow/python/tools/tools.bzl
index c6853e1fc63..79f771bbcad 100644
--- a/tensorflow/python/tools/tools.bzl
+++ b/tensorflow/python/tools/tools.bzl
@@ -1,6 +1,7 @@
 """Definitions for using tools like saved_model_cli."""
 
 load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available")
+load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
 load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple")
 
 def _maybe_force_compile(args, force_compile):
@@ -19,6 +20,7 @@ def saved_model_compile_aot(
         signature_def = "serving_default",
         variables_to_feed = "",
         target_triple = None,
+        target_cpu = None,
         force_without_xla_support_flag = True,
         tags = None):
     """Compile a SavedModel directory accessible from a filegroup.
@@ -88,7 +90,9 @@ def saved_model_compile_aot(
         uninitialized in the compiled object (this applies to all input
         arguments from the signature as well).
       target_triple: The LLVM target triple to use (defaults to current build's
-        target architecture's triple).
+        target architecture's triple). Similar to clang's -target flag.
+      target_cpu: The LLVM cpu name used for compilation. Similar to clang's
+        -mcpu flag.
      force_without_xla_support_flag: Whether to compile even when
        `--define=with_xla_support=true` is not set. If `False`, and the
        define is not passed when building, then the created `cc_library`
@@ -100,6 +104,7 @@ def saved_model_compile_aot(
     """
     saved_model = "{}/saved_model.pb".format(directory)
     target_triple = target_triple or target_llvm_triple()
+    target_cpu = target_cpu or tfcompile_target_cpu() or ""
     variables_to_feed = variables_to_feed or "''"
     if checkpoint_path:
         checkpoint_cmd_args = (
@@ -131,6 +136,7 @@ def saved_model_compile_aot(
             "--variables_to_feed {} ".format(variables_to_feed) +
             "--signature_def_key {} ".format(signature_def) +
             "--target_triple " + target_triple + " " +
+            ("--target_cpu " + target_cpu + " " if target_cpu else "") +
             "--tag_set {} ".format(tag_set)
         ),
         tags = tags,
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 70b03146f34..9a780839be3 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -2866,7 +2866,7 @@ def if_mlir(if_true, if_false = []):
         "//conditions:default": if_false,
     })
 
-def tfcompile_extra_flags():
+def tfcompile_target_cpu():
     return ""
 
 def tf_external_workspace_visible(visibility):
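
Usage sketch (not part of the diff above): one way the new flag might be exercised from the command line once this change is applied. The SavedModel directory, output prefix, and C++ class name below are hypothetical placeholders; every flag other than --target_cpu is a pre-existing option of the aot_compile_cpu subcommand.

    saved_model_cli aot_compile_cpu \
        --dir /tmp/my_saved_model \
        --tag_set serve \
        --signature_def_key serving_default \
        --output_prefix /tmp/aot/compiled_model \
        --cpp_class "mymodel::CompiledModel" \
        --target_triple x86_64-pc-linux \
        --target_cpu haswell

Leaving --target_cpu at its default (the empty string) keeps the target CPU unspecified, matching the else-branches added in tools.bzl and tfcompile.bzl and the empty-string default returned by tfcompile_target_cpu() in tensorflow.bzl.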