[TPU] Colocate tpu_computation_placer with other registration in compiler/xla

This is part of a series of changes to move TPU-related code to better locations so that the TensorFlow build isn't confused and TPU-based TF can be built without the define=framework_shared_object=false flag.

PiperOrigin-RevId: 353124524
Change-Id: Ib0157a70a9ea5d55be2cf754b07a0b8b5e84e90f
This commit is contained in:
Frank Chen 2021-01-21 16:05:46 -08:00 committed by TensorFlower Gardener
parent c7e57df256
commit 7ea6fcf6db
7 changed files with 58 additions and 27 deletions
tensorflow

View File

@ -242,13 +242,13 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:tpu_computation_placer",
"//tensorflow/core:lib",
"//tensorflow/core/tpu:tpu_executor_dlsym_initializer",
"//tensorflow/core/tpu:tpu_on_demand_compiler",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/tpu:tpu_computation_placer",
"//tensorflow/stream_executor/tpu:tpu_executable_interface",
"//tensorflow/stream_executor/tpu:tpu_executor",
"//tensorflow/stream_executor/tpu:tpu_executor_interface",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/tpu_computation_placer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h"
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

View File

@ -7,7 +7,7 @@ load(
"//tensorflow/core/platform:build_config.bzl",
"tf_proto_library",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_cc_test")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup")
@ -2872,6 +2872,9 @@ cc_library(
hdrs = ["computation_placer.h"],
deps = [
":global_device_id",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -2887,13 +2890,44 @@ cc_library(
"//tensorflow/stream_executor/cuda:cuda_platform_id",
"//tensorflow/stream_executor/host:host_platform_id",
"//tensorflow/stream_executor/rocm:rocm_platform_id",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
] + if_libtpu([":tpu_computation_placer"]),
alwayslink = True, # Contains per-platform computation placer registration
)
cc_library(
name = "computation_placer_hdr",
hdrs = ["computation_placer.h"],
deps = [
":global_device_id",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/core/platform:macros",
"//tensorflow/stream_executor:stream_header",
],
)
cc_library(
name = "tpu_computation_placer",
srcs = ["tpu_computation_placer.cc"],
hdrs = ["tpu_computation_placer.h"],
visibility = ["//visibility:public"],
deps = [
":computation_placer_hdr",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/stream_executor/tpu:status_helper",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_platform_hdr",
"//tensorflow/stream_executor/tpu:tpu_topology_external",
],
alwayslink = True, # Contains TPU computation placer registration
)
cc_library(
name = "human_readable_profile_builder",
srcs = ["human_readable_profile_builder.cc"],

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h"
#include "tensorflow/compiler/xla/service/tpu_computation_placer.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"

View File

@ -153,9 +153,9 @@ cc_library(
":tpu_executor_init_fns",
":tpu_library_init_fns",
":tpu_ops_c_api_hdrs",
"//tensorflow/compiler/xla/service:tpu_computation_placer",
"//tensorflow/core:lib",
"//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
"//tensorflow/stream_executor/tpu:tpu_computation_placer",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
],
)
@ -173,8 +173,8 @@ cc_library(
deps = [
":tpu_api_dlsym_set_fn",
":tpu_executor_init_fns",
"//tensorflow/compiler/xla/service:tpu_computation_placer",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/tpu:tpu_computation_placer",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
],
alwayslink = True,

View File

@ -6,6 +6,7 @@ package(
default_visibility = [
"//learning/brain/experimental/dtensor:__subpackages__",
"//tensorflow/compiler/jit:__subpackages__",
"//tensorflow/compiler/xla:__subpackages__",
"//tensorflow/compiler/xrt:__subpackages__",
"//tensorflow/core/profiler/internal/tpu:__subpackages__",
"//tensorflow/core/tpu:__subpackages__",
@ -126,6 +127,19 @@ cc_library(
alwayslink = True,
)
cc_library(
name = "tpu_platform_hdr",
hdrs = ["tpu_platform.h"],
deps = [
":tpu_executor_c_api_hdrs",
":tpu_platform_interface",
"//tensorflow/core:framework_lite",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/stream_executor:stream_header",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_library(
name = "tpu_executor_hdrs",
hdrs = [
@ -250,23 +264,6 @@ cc_library(
],
)
cc_library(
name = "tpu_computation_placer",
srcs = ["tpu_computation_placer.cc"],
hdrs = ["tpu_computation_placer.h"],
visibility = ["//visibility:public"],
deps = [
":status_helper",
":tpu_executor",
":tpu_executor_c_api_hdrs",
":tpu_topology_external",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/core/tpu:tpu_api",
],
alwayslink = True,
)
cc_library(
name = "tpu_executable",
srcs = ["tpu_executable.cc"],