From 7ea6fcf6dbefa9267a428019fece31667cfa4f2f Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Thu, 21 Jan 2021 16:05:46 -0800 Subject: [PATCH] [TPU] Colocate tpu_computation_placer with other registration in compiler/xla This is part of a series of changes to move TPU-related code to better locations so that the TensorFlow build isn't confused and TPU-based TF can be built without the define=framework_shared_object=false flag. PiperOrigin-RevId: 353124524 Change-Id: Ib0157a70a9ea5d55be2cf754b07a0b8b5e84e90f --- tensorflow/compiler/xla/pjrt/BUILD | 2 +- tensorflow/compiler/xla/pjrt/tpu_client.cc | 2 +- tensorflow/compiler/xla/service/BUILD | 44 ++++++++++++++++--- .../xla/service}/tpu_computation_placer.cc | 2 +- .../xla/service}/tpu_computation_placer.h | 0 tensorflow/core/tpu/BUILD | 4 +- tensorflow/stream_executor/tpu/BUILD | 31 ++++++------- 7 files changed, 58 insertions(+), 27 deletions(-) rename tensorflow/{stream_executor/tpu => compiler/xla/service}/tpu_computation_placer.cc (97%) rename tensorflow/{stream_executor/tpu => compiler/xla/service}/tpu_computation_placer.h (100%) diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 705d8b2f56b..fcf840d2fdb 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -242,13 +242,13 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:tpu_computation_placer", "//tensorflow/core:lib", "//tensorflow/core/tpu:tpu_executor_dlsym_initializer", "//tensorflow/core/tpu:tpu_on_demand_compiler", "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:stream", "//tensorflow/stream_executor/lib", - "//tensorflow/stream_executor/tpu:tpu_computation_placer", "//tensorflow/stream_executor/tpu:tpu_executable_interface", "//tensorflow/stream_executor/tpu:tpu_executor", "//tensorflow/stream_executor/tpu:tpu_executor_interface", diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc index 95f581479e5..068bd73a0d0 100644 --- a/tensorflow/compiler/xla/pjrt/tpu_client.cc +++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/tpu_computation_placer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/stream.h" -#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h" #include "tensorflow/stream_executor/tpu/tpu_executable_interface.h" #include "tensorflow/stream_executor/tpu/tpu_executor_interface.h" #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ea828c54553..61ce56ced2b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -7,7 +7,7 @@ load( "//tensorflow/core/platform:build_config.bzl", "tf_proto_library", ) -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_cc_test") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "filegroup") @@ -2872,6 +2872,9 @@ cc_library( hdrs = ["computation_placer.h"], deps = [ ":global_device_id", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2887,13 +2890,44 @@ cc_library( "//tensorflow/stream_executor/cuda:cuda_platform_id", "//tensorflow/stream_executor/host:host_platform_id", "//tensorflow/stream_executor/rocm:rocm_platform_id", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - ], + ] + if_libtpu([":tpu_computation_placer"]), alwayslink = True, # Contains per-platform computation placer registration ) +cc_library( + name = "computation_placer_hdr", + hdrs = ["computation_placer.h"], + deps = [ + ":global_device_id", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:tflite_portable_logging", + "//tensorflow/core/platform:macros", + "//tensorflow/stream_executor:stream_header", + ], +) + +cc_library( + name = "tpu_computation_placer", + srcs = ["tpu_computation_placer.cc"], + hdrs = ["tpu_computation_placer.h"], + visibility = ["//visibility:public"], + deps = [ + ":computation_placer_hdr", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core/tpu:tpu_api", + "//tensorflow/stream_executor/tpu:status_helper", + "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", + "//tensorflow/stream_executor/tpu:tpu_platform_hdr", + "//tensorflow/stream_executor/tpu:tpu_topology_external", + ], + alwayslink = True, # Contains TPU computation placer registration +) + cc_library( name = "human_readable_profile_builder", srcs = ["human_readable_profile_builder.cc"], diff --git a/tensorflow/stream_executor/tpu/tpu_computation_placer.cc b/tensorflow/compiler/xla/service/tpu_computation_placer.cc similarity index 97% rename from tensorflow/stream_executor/tpu/tpu_computation_placer.cc rename to tensorflow/compiler/xla/service/tpu_computation_placer.cc index 40e0117daad..a60f1ae1f68 100644 --- a/tensorflow/stream_executor/tpu/tpu_computation_placer.cc +++ b/tensorflow/compiler/xla/service/tpu_computation_placer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h" +#include "tensorflow/compiler/xla/service/tpu_computation_placer.h" #include "tensorflow/core/tpu/tpu_api.h" #include "tensorflow/stream_executor/tpu/status_helper.h" diff --git a/tensorflow/stream_executor/tpu/tpu_computation_placer.h b/tensorflow/compiler/xla/service/tpu_computation_placer.h similarity index 100% rename from tensorflow/stream_executor/tpu/tpu_computation_placer.h rename to tensorflow/compiler/xla/service/tpu_computation_placer.h diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index b7c294df7ef..e8995d7c3b5 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -153,9 +153,9 @@ cc_library( ":tpu_executor_init_fns", ":tpu_library_init_fns", ":tpu_ops_c_api_hdrs", + "//tensorflow/compiler/xla/service:tpu_computation_placer", "//tensorflow/core:lib", "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration", - "//tensorflow/stream_executor/tpu:tpu_computation_placer", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", ], ) @@ -173,8 +173,8 @@ cc_library( deps = [ ":tpu_api_dlsym_set_fn", ":tpu_executor_init_fns", + "//tensorflow/compiler/xla/service:tpu_computation_placer", "//tensorflow/core:lib", - "//tensorflow/stream_executor/tpu:tpu_computation_placer", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", ], alwayslink = True, diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD index 409d1ccce54..20847f2e3a2 100644 --- a/tensorflow/stream_executor/tpu/BUILD +++ b/tensorflow/stream_executor/tpu/BUILD @@ -6,6 +6,7 @@ package( default_visibility = [ "//learning/brain/experimental/dtensor:__subpackages__", "//tensorflow/compiler/jit:__subpackages__", + "//tensorflow/compiler/xla:__subpackages__", "//tensorflow/compiler/xrt:__subpackages__", "//tensorflow/core/profiler/internal/tpu:__subpackages__", "//tensorflow/core/tpu:__subpackages__", @@ -126,6 +127,19 @@ cc_library( alwayslink = True, ) +cc_library( + name = "tpu_platform_hdr", + hdrs = ["tpu_platform.h"], + deps = [ + ":tpu_executor_c_api_hdrs", + ":tpu_platform_interface", + "//tensorflow/core:framework_lite", + "//tensorflow/core:tflite_portable_logging", + "//tensorflow/stream_executor:stream_header", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + cc_library( name = "tpu_executor_hdrs", hdrs = [ @@ -250,23 +264,6 @@ cc_library( ], ) -cc_library( - name = "tpu_computation_placer", - srcs = ["tpu_computation_placer.cc"], - hdrs = ["tpu_computation_placer.h"], - visibility = ["//visibility:public"], - deps = [ - ":status_helper", - ":tpu_executor", - ":tpu_executor_c_api_hdrs", - ":tpu_topology_external", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/core/tpu:tpu_api", - ], - alwayslink = True, -) - cc_library( name = "tpu_executable", srcs = ["tpu_executable.cc"],