From bdc490c868de3288d9bb57e0086b3ce4022a0fb4 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Fri, 22 Jan 2021 16:32:02 -0800 Subject: [PATCH] [TPU] Colocate tpu_compilation_device registration with other registration in compiler/tf2xla This is part of a series of changes to move TPU-related code to better locations so that the TensorFlow build isn't confused and TPU-based TF can be built without the define=framework_shared_object=false flag. PiperOrigin-RevId: 353341905 Change-Id: I7cd820aab37c3d4f1a838967a87dc462491e446f --- tensorflow/compiler/tf2xla/BUILD | 28 ++++++++++++++----- .../tf2xla/xla_tpu_backend.cc} | 0 tensorflow/core/tpu/BUILD | 15 ---------- 3 files changed, 21 insertions(+), 22 deletions(-) rename tensorflow/{core/tpu/tpu_compilation_device.cc => compiler/tf2xla/xla_tpu_backend.cc} (100%) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index e399eece0e1..31329590f1d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -350,11 +350,16 @@ cc_library( ":xla_helpers", ":xla_op_registry", ":xla_resource", + "//tensorflow/compiler/mlir:mlir_bridge_rollout_policy", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/mlir:array_container_utils", - "//tensorflow/compiler/mlir:mlir_bridge_rollout_policy", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", @@ -373,12 +378,9 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", - ], + ] + if_libtpu([ + ":xla_tpu_backend_registration", + ]), alwayslink = 1, ) @@ -405,6 +407,18 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_tpu_backend_registration", + srcs = ["xla_tpu_backend.cc"], + visibility = ["//visibility:public"], + deps = [ + ":xla_op_registry", + "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/core/tpu:tpu_node_device_util", + ], + alwayslink = 1, +) + cc_library( name = "xla_context", srcs = [ diff --git a/tensorflow/core/tpu/tpu_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_tpu_backend.cc similarity index 100% rename from tensorflow/core/tpu/tpu_compilation_device.cc rename to tensorflow/compiler/tf2xla/xla_tpu_backend.cc diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 3a4114793cf..e2e7aacb9cd 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -55,19 +55,6 @@ cc_library( ], ) -cc_library( - name = "tpu_compilation_device", - srcs = ["tpu_compilation_device.cc"], - visibility = ["//visibility:public"], - deps = [ - ":tpu_defs", - ":tpu_node_device_util", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla:xla_op_registry", - ], - alwayslink = 1, -) - cc_library( name = "tpu_node_device_util", srcs = ["tpu_node_device_util.cc"], @@ -149,7 +136,6 @@ cc_library( ":libtftpu_header", ":tpu_api", ":tpu_api_dlsym_set_fn", - ":tpu_compilation_device", ":tpu_executor_init_fns", ":tpu_library_init_fns", ":tpu_ops_c_api_hdrs", @@ -285,7 +271,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":tpu_api_dlsym_initializer", - ":tpu_compilation_device", "//tensorflow/core/tpu:tpu_on_demand_compiler", "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration", "//tensorflow/core/tpu/ops",