[TF saved_model_cli] Allow user to set target_cpu for xla aot compilation.
PiperOrigin-RevId: 312290453 Change-Id: I024e2b3884e436578e351d43961199e4c28307c3
This commit is contained in:
		
							parent
							
								
									fb7ba76670
								
							
						
					
					
						commit
						e25d9862ca
					
				@ -20,7 +20,7 @@ load(
 | 
			
		||||
    "tf_cc_test",
 | 
			
		||||
    "tf_copts",
 | 
			
		||||
)
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
 | 
			
		||||
 | 
			
		||||
def tf_library(
 | 
			
		||||
        name,
 | 
			
		||||
@ -188,7 +188,9 @@ def tf_library(
 | 
			
		||||
    # `find` on such an object.
 | 
			
		||||
    need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
 | 
			
		||||
 | 
			
		||||
    flags = tfcompile_extra_flags() + flags
 | 
			
		||||
    target_cpu = tfcompile_target_cpu()
 | 
			
		||||
    extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
 | 
			
		||||
    flags = extra_flags + flags
 | 
			
		||||
 | 
			
		||||
    if enable_xla_hlo_profiling:
 | 
			
		||||
        profiling_flag = "--xla_hlo_profile"
 | 
			
		||||
 | 
			
		||||
@ -215,6 +215,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
 | 
			
		||||
                                   signature_def_key,
 | 
			
		||||
                                   cpp_class,
 | 
			
		||||
                                   target_triple,
 | 
			
		||||
                                   target_cpu,
 | 
			
		||||
                                   variables_to_feed=(),
 | 
			
		||||
                                   enable_multithreading=False):
 | 
			
		||||
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.
 | 
			
		||||
@ -239,6 +240,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
 | 
			
		||||
    signature_def_key: String, the signature_def to use in the SavedModel.
 | 
			
		||||
    cpp_class: String, Name of output C++ class.
 | 
			
		||||
    target_triple: String, LLVM target triple.
 | 
			
		||||
    target_cpu: String, LLVM target cpu name.
 | 
			
		||||
    variables_to_feed: A list of strings, the variables that will be fed by the
 | 
			
		||||
      user; these won't be frozen.  If `None`, then we will extract all the
 | 
			
		||||
      variables in the graph and mark them as to-feed.  The default behavior is
 | 
			
		||||
@ -367,6 +369,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
 | 
			
		||||
      config=config_pbtxt_location,
 | 
			
		||||
      cpp_class=cpp_class,
 | 
			
		||||
      target_triple=target_triple,
 | 
			
		||||
      target_cpu=target_cpu,
 | 
			
		||||
      entry_point='entry_{}'.format(entry_digest),
 | 
			
		||||
      out_function_object='{}.o'.format(output_prefix),
 | 
			
		||||
      out_header='{}.h'.format(output_prefix),
 | 
			
		||||
 | 
			
		||||
@ -825,6 +825,7 @@ def aot_compile_cpu(args):
 | 
			
		||||
      variables_to_feed=variables_to_feed,
 | 
			
		||||
      output_prefix=args.output_prefix,
 | 
			
		||||
      target_triple=args.target_triple,
 | 
			
		||||
      target_cpu=args.target_cpu,
 | 
			
		||||
      cpp_class=args.cpp_class,
 | 
			
		||||
      enable_multithreading=args.enable_multithreading)
 | 
			
		||||
 | 
			
		||||
@ -1096,6 +1097,14 @@ def add_aot_compile_cpu_subparser(subparsers):
 | 
			
		||||
            'x86_64-none-darwin, x86_64-apple-ios, arm64-none-ios, '
 | 
			
		||||
            'armv7-none-android.  More examples are available in tfcompile.bzl '
 | 
			
		||||
            'in the tensorflow codebase.'))
 | 
			
		||||
  parser_compile.add_argument(
 | 
			
		||||
      '--target_cpu',
 | 
			
		||||
      type=str,
 | 
			
		||||
      default='',
 | 
			
		||||
      help=('Target cpu name for LLVM during AOT compilation.  Examples: '
 | 
			
		||||
            'x86_64, skylake, haswell, westmere, <empty> (unknown).  For '
 | 
			
		||||
            'a complete list of options, run (for x86 targets): '
 | 
			
		||||
            '`llc -march=x86 -mcpu=help`'))
 | 
			
		||||
  parser_compile.add_argument(
 | 
			
		||||
      '--checkpoint_path',
 | 
			
		||||
      type=str,
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
"""Definitions for using tools like saved_model_cli."""
 | 
			
		||||
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available")
 | 
			
		||||
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
 | 
			
		||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple")
 | 
			
		||||
 | 
			
		||||
def _maybe_force_compile(args, force_compile):
 | 
			
		||||
@ -19,6 +20,7 @@ def saved_model_compile_aot(
 | 
			
		||||
        signature_def = "serving_default",
 | 
			
		||||
        variables_to_feed = "",
 | 
			
		||||
        target_triple = None,
 | 
			
		||||
        target_cpu = None,
 | 
			
		||||
        force_without_xla_support_flag = True,
 | 
			
		||||
        tags = None):
 | 
			
		||||
    """Compile a SavedModel directory accessible from a filegroup.
 | 
			
		||||
@ -88,7 +90,9 @@ def saved_model_compile_aot(
 | 
			
		||||
        uninitialized in the compiled object (this applies to all input
 | 
			
		||||
        arguments from the signature as well).
 | 
			
		||||
      target_triple: The LLVM target triple to use (defaults to current build's
 | 
			
		||||
        target architecture's triple).
 | 
			
		||||
        target architecture's triple).  Similar to clang's -target flag.
 | 
			
		||||
      target_cpu: The LLVM cpu name used for compilation.  Similar to clang's
 | 
			
		||||
        -mcpu flag.
 | 
			
		||||
      force_without_xla_support_flag: Whether to compile even when
 | 
			
		||||
        `--define=with_xla_support=true` is not set.  If `False`, and the
 | 
			
		||||
        define is not passed when building, then the created `cc_library`
 | 
			
		||||
@ -100,6 +104,7 @@ def saved_model_compile_aot(
 | 
			
		||||
    """
 | 
			
		||||
    saved_model = "{}/saved_model.pb".format(directory)
 | 
			
		||||
    target_triple = target_triple or target_llvm_triple()
 | 
			
		||||
    target_cpu = target_cpu or tfcompile_target_cpu() or ""
 | 
			
		||||
    variables_to_feed = variables_to_feed or "''"
 | 
			
		||||
    if checkpoint_path:
 | 
			
		||||
        checkpoint_cmd_args = (
 | 
			
		||||
@ -131,6 +136,7 @@ def saved_model_compile_aot(
 | 
			
		||||
            "--variables_to_feed {} ".format(variables_to_feed) +
 | 
			
		||||
            "--signature_def_key {} ".format(signature_def) +
 | 
			
		||||
            "--target_triple " + target_triple + " " +
 | 
			
		||||
            ("--target_cpu " + target_cpu + " " if target_cpu else "") +
 | 
			
		||||
            "--tag_set {} ".format(tag_set)
 | 
			
		||||
        ),
 | 
			
		||||
        tags = tags,
 | 
			
		||||
 | 
			
		||||
@ -2866,7 +2866,7 @@ def if_mlir(if_true, if_false = []):
 | 
			
		||||
        "//conditions:default": if_false,
 | 
			
		||||
    })
 | 
			
		||||
 | 
			
		||||
def tfcompile_extra_flags():
 | 
			
		||||
def tfcompile_target_cpu():
 | 
			
		||||
    return ""
 | 
			
		||||
 | 
			
		||||
def tf_external_workspace_visible(visibility):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user