diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD new file mode 100644 index 00000000000..ddc4c007c09 --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -0,0 +1,212 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "worker_thread", + srcs = ["worker_thread.cc"], + hdrs = ["worker_thread.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "event_pool", + srcs = ["event_pool.cc"], + hdrs = ["event_pool.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "semaphore", + srcs = ["semaphore.cc"], + hdrs = ["semaphore.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "semaphore_test", + srcs = ["semaphore_test.cc"], + deps = [ + ":semaphore", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "tracked_device_buffer", + srcs = ["tracked_device_buffer.cc"], + hdrs = ["tracked_device_buffer.h"], + deps = [ + ":event_pool", + ":local_device_state", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/core:lib", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:device_memory_allocator", + "//tensorflow/stream_executor:event", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "tracked_device_buffer_test", + srcs = ["tracked_device_buffer_test.cc"], + deps = [ + ":tracked_device_buffer", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:device_memory_allocator", + ], +) + +cc_library( + name = "local_device_state", + srcs = ["local_device_state.cc"], + hdrs = ["local_device_state.h"], + deps = [ + ":event_pool", + ":semaphore", + ":worker_thread", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "//tensorflow/stream_executor:event", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "pjrt_client", + srcs = ["pjrt_client.cc"], + hdrs = ["pjrt_client.h"], + visibility = ["//tensorflow/compiler/xla:friends"], + deps = [ + ":event_pool", + ":local_device_state", + ":tracked_device_buffer", + "//tensorflow/compiler/xla:cpu_function_runtime", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:maybe_owning_device_memory", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", + "//tensorflow/core:allocator", + "//tensorflow/core:lib", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/host:host_platform_id", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "cpu_device", + srcs = ["cpu_device.cc"], + hdrs = ["cpu_device.h"], + deps = [ + ":pjrt_client", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:platform_util", + ], +) + +cc_library( + name = "nvidia_gpu_device", + srcs = ["nvidia_gpu_device.cc"], + hdrs = ["nvidia_gpu_device.h"], + copts = if_cuda(["-DNCCL_ENABLED=1"]), + deps = [ + ":pjrt_client", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/pjrt/distributed:client", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core/common_runtime:bfc_allocator", + "//tensorflow/core/common_runtime/gpu:gpu_mem_allocator", + "//tensorflow/stream_executor:tf_allocator_adapter", + ] + if_cuda(["@local_config_nccl//:nccl"]), +) + +tf_cc_test( + name = "gpu_multistream_test", + srcs = ["gpu_multistream_test.cc"], + tags = [ + # TODO(phawkins): figure out TF test infra such that this only runs under GPU. + "no_oss", + "requires-gpu-nvidia", + ], + deps = [ + ":nvidia_gpu_device", + ":pjrt_client", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:random", + ], +) diff --git a/tensorflow/compiler/xla/python/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc similarity index 97% rename from tensorflow/compiler/xla/python/cpu_device.cc rename to tensorflow/compiler/xla/pjrt/cpu_device.cc index 12e1e55723b..f2bc472ed09 100644 --- a/tensorflow/compiler/xla/python/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/cpu_device.h" +#include "tensorflow/compiler/xla/pjrt/cpu_device.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/python/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h similarity index 81% rename from tensorflow/compiler/xla/python/cpu_device.h rename to tensorflow/compiler/xla/pjrt/cpu_device.h index 38e81644b1e..c70d90ae228 100644 --- a/tensorflow/compiler/xla/python/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_ #include -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -32,4 +32,4 @@ StatusOr> GetCpuClient(bool asynchronous); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/BUILD b/tensorflow/compiler/xla/pjrt/distributed/BUILD similarity index 100% rename from tensorflow/compiler/xla/python/distributed/BUILD rename to tensorflow/compiler/xla/pjrt/distributed/BUILD diff --git a/tensorflow/compiler/xla/python/distributed/client.cc b/tensorflow/compiler/xla/pjrt/distributed/client.cc similarity index 94% rename from tensorflow/compiler/xla/python/distributed/client.cc rename to tensorflow/compiler/xla/pjrt/distributed/client.cc index c50c3f50a9d..830e512b156 100644 --- a/tensorflow/compiler/xla/python/distributed/client.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/client.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" #include // NOLINT -#include "tensorflow/compiler/xla/python/distributed/protocol.h" -#include "tensorflow/compiler/xla/python/distributed/util.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" +#include "tensorflow/compiler/xla/pjrt/distributed/util.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/distributed/client.h b/tensorflow/compiler/xla/pjrt/distributed/client.h similarity index 85% rename from tensorflow/compiler/xla/python/distributed/client.h rename to tensorflow/compiler/xla/pjrt/distributed/client.h index 1ab5292bea8..865a752849e 100644 --- a/tensorflow/compiler/xla/python/distributed/client.h +++ b/tensorflow/compiler/xla/pjrt/distributed/client.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ #include #include "grpcpp/channel.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/platform/env.h" @@ -47,4 +47,4 @@ class DistributedRuntimeClient { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/distributed/client_server_test.cc b/tensorflow/compiler/xla/pjrt/distributed/client_server_test.cc similarity index 95% rename from tensorflow/compiler/xla/python/distributed/client_server_test.cc rename to tensorflow/compiler/xla/pjrt/distributed/client_server_test.cc index e78949933a2..cfe60a06207 100644 --- a/tensorflow/compiler/xla/python/distributed/client_server_test.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/client_server_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include "grpcpp/security/server_credentials.h" #include "absl/time/time.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" -#include "tensorflow/compiler/xla/python/distributed/service.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/python/distributed/distributed.cc b/tensorflow/compiler/xla/pjrt/distributed/distributed.cc similarity index 95% rename from tensorflow/compiler/xla/python/distributed/distributed.cc rename to tensorflow/compiler/xla/pjrt/distributed/distributed.cc index 6afc7b1c4e9..7753e2dcfc7 100644 --- a/tensorflow/compiler/xla/python/distributed/distributed.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/distributed.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/distributed.h" +#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h" #include "grpcpp/grpcpp.h" diff --git a/tensorflow/compiler/xla/python/distributed/distributed.h b/tensorflow/compiler/xla/pjrt/distributed/distributed.h similarity index 83% rename from tensorflow/compiler/xla/python/distributed/distributed.h rename to tensorflow/compiler/xla/pjrt/distributed/distributed.h index 0475c3e9feb..b3909387259 100644 --- a/tensorflow/compiler/xla/python/distributed/distributed.h +++ b/tensorflow/compiler/xla/pjrt/distributed/distributed.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_ #include #include -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -43,4 +43,4 @@ std::shared_ptr GetDistributedRuntimeClient( } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_ diff --git a/tensorflow/compiler/xla/python/distributed/key_value_store.cc b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.cc similarity index 95% rename from tensorflow/compiler/xla/python/distributed/key_value_store.cc rename to tensorflow/compiler/xla/pjrt/distributed/key_value_store.cc index 5966d4ce12b..e989b1384d2 100644 --- a/tensorflow/compiler/xla/python/distributed/key_value_store.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/key_value_store.h" +#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/distributed/key_value_store.h b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.h similarity index 89% rename from tensorflow/compiler/xla/python/distributed/key_value_store.h rename to tensorflow/compiler/xla/pjrt/distributed/key_value_store.h index 8560305e6f6..d496de1feb5 100644 --- a/tensorflow/compiler/xla/python/distributed/key_value_store.h +++ b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ #include "grpcpp/grpcpp.h" #include "absl/base/thread_annotations.h" @@ -50,4 +50,4 @@ class KeyValueStore { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/protocol.h b/tensorflow/compiler/xla/pjrt/distributed/protocol.h similarity index 80% rename from tensorflow/compiler/xla/python/distributed/protocol.h rename to tensorflow/compiler/xla/pjrt/distributed/protocol.h index 208c6dab8c5..4daa939ac8d 100644 --- a/tensorflow/compiler/xla/python/distributed/protocol.h +++ b/tensorflow/compiler/xla/pjrt/distributed/protocol.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ namespace xla { @@ -22,4 +22,4 @@ static constexpr int kDistributedRuntimeProtocolVersion = 1; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ diff --git a/tensorflow/compiler/xla/python/distributed/protocol.proto b/tensorflow/compiler/xla/pjrt/distributed/protocol.proto similarity index 100% rename from tensorflow/compiler/xla/python/distributed/protocol.proto rename to tensorflow/compiler/xla/pjrt/distributed/protocol.proto diff --git a/tensorflow/compiler/xla/python/distributed/service.cc b/tensorflow/compiler/xla/pjrt/distributed/service.cc similarity index 96% rename from tensorflow/compiler/xla/python/distributed/service.cc rename to tensorflow/compiler/xla/pjrt/distributed/service.cc index cc2b3a5aca2..3325fcd8319 100644 --- a/tensorflow/compiler/xla/python/distributed/service.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/service.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.h" -#include "tensorflow/compiler/xla/python/distributed/util.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" +#include "tensorflow/compiler/xla/pjrt/distributed/util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/python/distributed/service.h b/tensorflow/compiler/xla/pjrt/distributed/service.h similarity index 91% rename from tensorflow/compiler/xla/python/distributed/service.h rename to tensorflow/compiler/xla/pjrt/distributed/service.h index baf470e4f13..725a76791ce 100644 --- a/tensorflow/compiler/xla/python/distributed/service.h +++ b/tensorflow/compiler/xla/pjrt/distributed/service.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ #include "absl/time/time.h" -#include "tensorflow/compiler/xla/python/distributed/key_value_store.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -98,4 +98,4 @@ void BuildGlobalTopology(absl::Span local_topologies, } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/service_test.cc b/tensorflow/compiler/xla/pjrt/distributed/service_test.cc similarity index 91% rename from tensorflow/compiler/xla/python/distributed/service_test.cc rename to tensorflow/compiler/xla/pjrt/distributed/service_test.cc index 08326df2f38..b56dbb17d1a 100644 --- a/tensorflow/compiler/xla/python/distributed/service_test.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/service_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/python/distributed/util.h b/tensorflow/compiler/xla/pjrt/distributed/util.h similarity index 87% rename from tensorflow/compiler/xla/python/distributed/util.h rename to tensorflow/compiler/xla/pjrt/distributed/util.h index 07ae8d1f0ce..abb2b6089e7 100644 --- a/tensorflow/compiler/xla/python/distributed/util.h +++ b/tensorflow/compiler/xla/pjrt/distributed/util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_ #include "grpcpp/support/status.h" #include "tensorflow/compiler/xla/status.h" @@ -41,4 +41,4 @@ inline ::grpc::Status ToGrpcStatus(const Status& s) { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/event_pool.cc b/tensorflow/compiler/xla/pjrt/event_pool.cc similarity index 96% rename from tensorflow/compiler/xla/python/event_pool.cc rename to tensorflow/compiler/xla/pjrt/event_pool.cc index c7b52f523d9..86aa38cdd0f 100644 --- a/tensorflow/compiler/xla/python/event_pool.cc +++ b/tensorflow/compiler/xla/pjrt/event_pool.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" #include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" diff --git a/tensorflow/compiler/xla/python/event_pool.h b/tensorflow/compiler/xla/pjrt/event_pool.h similarity index 95% rename from tensorflow/compiler/xla/python/event_pool.h rename to tensorflow/compiler/xla/pjrt/event_pool.h index bda3fb6baff..47768c28fd9 100644 --- a/tensorflow/compiler/xla/python/event_pool.h +++ b/tensorflow/compiler/xla/pjrt/event_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_ #include #include @@ -87,4 +87,4 @@ class EventPool { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_ diff --git a/tensorflow/compiler/xla/python/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc similarity index 97% rename from tensorflow/compiler/xla/python/gpu_multistream_test.cc rename to tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc index bc6ecb14ae2..2db7de3720d 100644 --- a/tensorflow/compiler/xla/python/gpu_multistream_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/python/local_client.h" -#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/random.h" diff --git a/tensorflow/compiler/xla/python/local_device_state.cc b/tensorflow/compiler/xla/pjrt/local_device_state.cc similarity index 98% rename from tensorflow/compiler/xla/python/local_device_state.cc rename to tensorflow/compiler/xla/pjrt/local_device_state.cc index 6a96908cb12..d173c891c95 100644 --- a/tensorflow/compiler/xla/python/local_device_state.cc +++ b/tensorflow/compiler/xla/pjrt/local_device_state.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include #include diff --git a/tensorflow/compiler/xla/python/local_device_state.h b/tensorflow/compiler/xla/pjrt/local_device_state.h similarity index 96% rename from tensorflow/compiler/xla/python/local_device_state.h rename to tensorflow/compiler/xla/pjrt/local_device_state.h index 5cd2c0014a0..eb25c37878f 100644 --- a/tensorflow/compiler/xla/python/local_device_state.h +++ b/tensorflow/compiler/xla/pjrt/local_device_state.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_ #include #include @@ -22,9 +22,9 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/python/event_pool.h" -#include "tensorflow/compiler/xla/python/semaphore.h" -#include "tensorflow/compiler/xla/python/worker_thread.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/worker_thread.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/stream_executor.h" @@ -207,4 +207,4 @@ class LocalDeviceState { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_ diff --git a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc similarity index 99% rename from tensorflow/compiler/xla/python/nvidia_gpu_device.cc rename to tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index 886ed697f4e..4863e5e8165 100644 --- a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" #ifdef NCCL_ENABLED #include "third_party/nccl/nccl.h" diff --git a/tensorflow/compiler/xla/python/nvidia_gpu_device.h b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h similarity index 87% rename from tensorflow/compiler/xla/python/nvidia_gpu_device.h rename to tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h index 2f9922454fa..bf59ddef3a9 100644 --- a/tensorflow/compiler/xla/python/nvidia_gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_ #include -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/bfc_allocator.h" @@ -59,4 +59,4 @@ StatusOr> GetNvidiaGpuClient( } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_ diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc similarity index 99% rename from tensorflow/compiler/xla/python/local_client.cc rename to tensorflow/compiler/xla/pjrt/pjrt_client.cc index f2acd0d6398..80fd0e0b658 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -62,7 +62,7 @@ limitations under the License. // See the comment on LocalDeviceState::AllocationModel for a discussion of the // different allocation semantics on CPU, GPU, and TPU. -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include #include @@ -83,10 +83,10 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" -#include "tensorflow/compiler/xla/python/event_pool.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h similarity index 99% rename from tensorflow/compiler/xla/python/local_client.h rename to tensorflow/compiler/xla/pjrt/pjrt_client.h index f09e70037d6..775b44c7073 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ #include #include @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -681,4 +681,4 @@ class PjRtExecutable { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/semaphore.cc b/tensorflow/compiler/xla/pjrt/semaphore.cc similarity index 97% rename from tensorflow/compiler/xla/python/semaphore.cc rename to tensorflow/compiler/xla/pjrt/semaphore.cc index 5926618bddc..c1df52acc61 100644 --- a/tensorflow/compiler/xla/python/semaphore.cc +++ b/tensorflow/compiler/xla/pjrt/semaphore.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/python/semaphore.h b/tensorflow/compiler/xla/pjrt/semaphore.h similarity index 92% rename from tensorflow/compiler/xla/python/semaphore.h rename to tensorflow/compiler/xla/pjrt/semaphore.h index 7d3e9ce6271..45345becf74 100644 --- a/tensorflow/compiler/xla/python/semaphore.h +++ b/tensorflow/compiler/xla/pjrt/semaphore.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_ #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/types.h" @@ -65,4 +65,4 @@ class Semaphore { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_ diff --git a/tensorflow/compiler/xla/python/semaphore_test.cc b/tensorflow/compiler/xla/pjrt/semaphore_test.cc similarity index 97% rename from tensorflow/compiler/xla/python/semaphore_test.cc rename to tensorflow/compiler/xla/pjrt/semaphore_test.cc index 5ef59618b8b..56f7e8c9a05 100644 --- a/tensorflow/compiler/xla/python/semaphore_test.cc +++ b/tensorflow/compiler/xla/pjrt/semaphore_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "absl/synchronization/notification.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/python/tracked_device_buffer.cc b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc similarity index 98% rename from tensorflow/compiler/xla/python/tracked_device_buffer.cc rename to tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc index 5c6dbbf3289..32ca4e4550c 100644 --- a/tensorflow/compiler/xla/python/tracked_device_buffer.cc +++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include #include #include "absl/synchronization/mutex.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/stream_executor/device_memory.h" diff --git a/tensorflow/compiler/xla/python/tracked_device_buffer.h b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h similarity index 97% rename from tensorflow/compiler/xla/python/tracked_device_buffer.h rename to tensorflow/compiler/xla/pjrt/tracked_device_buffer.h index 27e7de6e2c2..562cb2f913e 100644 --- a/tensorflow/compiler/xla/python/tracked_device_buffer.h +++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ #include #include "absl/container/flat_hash_set.h" -#include "tensorflow/compiler/xla/python/event_pool.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape.h" @@ -257,4 +257,4 @@ void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer, } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ diff --git a/tensorflow/compiler/xla/python/tracked_device_buffer_test.cc b/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc similarity index 98% rename from tensorflow/compiler/xla/python/tracked_device_buffer_test.cc rename to tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc index 354176654af..9373b57e7d1 100644 --- a/tensorflow/compiler/xla/python/tracked_device_buffer_test.cc +++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include diff --git a/tensorflow/compiler/xla/python/worker_thread.cc b/tensorflow/compiler/xla/pjrt/worker_thread.cc similarity index 96% rename from tensorflow/compiler/xla/python/worker_thread.cc rename to tensorflow/compiler/xla/pjrt/worker_thread.cc index d3fb02023a5..e8194534aef 100644 --- a/tensorflow/compiler/xla/python/worker_thread.cc +++ b/tensorflow/compiler/xla/pjrt/worker_thread.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/worker_thread.h" +#include "tensorflow/compiler/xla/pjrt/worker_thread.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/worker_thread.h b/tensorflow/compiler/xla/pjrt/worker_thread.h similarity index 90% rename from tensorflow/compiler/xla/python/worker_thread.h rename to tensorflow/compiler/xla/pjrt/worker_thread.h index 598f7b1d4ae..4fd2baa4cda 100644 --- a/tensorflow/compiler/xla/python/worker_thread.h +++ b/tensorflow/compiler/xla/pjrt/worker_thread.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_ #include #include @@ -51,4 +51,4 @@ class WorkerThread { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_ diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 3eb93f9559e..8c6bc84cf8e 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -1,7 +1,5 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps") -load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_test") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "pybind_extension") @@ -78,16 +76,6 @@ py_test( ] + xla_py_test_deps(), ) -cc_library( - name = "worker_thread", - srcs = ["worker_thread.cc"], - hdrs = ["worker_thread.h"], - deps = [ - "//tensorflow/core:lib", - "@com_google_absl//absl/synchronization", - ], -) - cc_library( name = "types", srcs = ["types.cc"], @@ -99,7 +87,6 @@ cc_library( features = ["-use_header_modules"], deps = [ ":bfloat16", - ":local_client", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -107,6 +94,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/core:lib", "//third_party/py/numpy:headers", "@com_google_absl//absl/container:flat_hash_map", @@ -116,148 +104,6 @@ cc_library( ], ) -cc_library( - name = "event_pool", - srcs = ["event_pool.cc"], - hdrs = ["event_pool.h"], - deps = [ - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - name = "semaphore", - srcs = ["semaphore.cc"], - hdrs = ["semaphore.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - ], -) - -tf_cc_test( - name = "semaphore_test", - srcs = ["semaphore_test.cc"], - deps = [ - ":semaphore", - "//tensorflow/compiler/xla:test", - "//tensorflow/core:lib", - "//tensorflow/core:test_main", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - name = "tracked_device_buffer", - srcs = ["tracked_device_buffer.cc"], - hdrs = ["tracked_device_buffer.h"], - deps = [ - ":event_pool", - ":local_device_state", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/core:lib", - "//tensorflow/stream_executor:device_memory", - "//tensorflow/stream_executor:device_memory_allocator", - "//tensorflow/stream_executor:event", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/synchronization", - ], -) - -tf_cc_test( - name = "tracked_device_buffer_test", - srcs = ["tracked_device_buffer_test.cc"], - deps = [ - ":tracked_device_buffer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/service:cpu_plugin", - "//tensorflow/core:test_main", - "//tensorflow/stream_executor:device_memory", - "//tensorflow/stream_executor:device_memory_allocator", - ], -) - -cc_library( - name = "local_device_state", - srcs = ["local_device_state.cc"], - hdrs = ["local_device_state.h"], - deps = [ - ":event_pool", - ":semaphore", - ":worker_thread", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "//tensorflow/stream_executor:event", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - name = "local_client", - srcs = ["local_client.cc"], - hdrs = ["local_client.h"], - visibility = ["//tensorflow/compiler/xla:friends"], - deps = [ - ":event_pool", - ":local_device_state", - ":tracked_device_buffer", - "//tensorflow/compiler/xla:cpu_function_runtime", - "//tensorflow/compiler/xla:executable_run_options", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/python/distributed:protocol_proto_cc", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:executable", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:maybe_owning_device_memory", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", - "//tensorflow/core:allocator", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/stream_executor:event", - "//tensorflow/stream_executor:stream", - "//tensorflow/stream_executor/host:host_platform_id", - "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - ], -) - cc_library( name = "python_ref_manager", srcs = ["python_ref_manager.cc"], @@ -322,10 +168,10 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - ":local_client", - ":tracked_device_buffer", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tracked_device_buffer", "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:platform", "//tensorflow/stream_executor/cuda:cuda_platform_id", @@ -340,37 +186,6 @@ cc_library( ], ) -cc_library( - name = "cpu_device", - srcs = ["cpu_device.cc"], - hdrs = ["cpu_device.h"], - deps = [ - ":local_client", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/service:platform_util", - ], -) - -cc_library( - name = "nvidia_gpu_device", - srcs = ["nvidia_gpu_device.cc"], - hdrs = ["nvidia_gpu_device.h"], - copts = if_cuda(["-DNCCL_ENABLED=1"]), - deps = [ - ":local_client", - "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/python/distributed:client", - "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla:util", - "//tensorflow/core/common_runtime:bfc_allocator", - "//tensorflow/core/common_runtime/gpu:gpu_mem_allocator", - "//tensorflow/stream_executor:tf_allocator_adapter", - ] + if_cuda(["@local_config_nccl//:nccl"]), -) - config_setting( name = "enable_gpu", values = {"define": "xla_python_enable_gpu=true"}, @@ -389,11 +204,7 @@ pybind_extension( module_name = "xla_extension", deps = [ ":bfloat16", - ":cpu_device", ":dlpack", - ":local_client", - ":nvidia_gpu_device", - ":tracked_device_buffer", ":python_ref_manager", ":types", "@com_google_absl//absl/base", @@ -423,9 +234,13 @@ pybind_extension( "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:svd", - "//tensorflow/compiler/xla/python/distributed", - "//tensorflow/compiler/xla/python/distributed:client", - "//tensorflow/compiler/xla/python/distributed:service", + "//tensorflow/compiler/xla/pjrt:cpu_device", + "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tracked_device_buffer", + "//tensorflow/compiler/xla/pjrt/distributed", + "//tensorflow/compiler/xla/pjrt/distributed:client", + "//tensorflow/compiler/xla/pjrt/distributed:service", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:hlo", @@ -454,25 +269,3 @@ pybind_extension( "//conditions:default": [], }), ) - -tf_cc_test( - name = "gpu_multistream_test", - srcs = ["gpu_multistream_test.cc"], - tags = [ - # TODO(phawkins): figure out TF test infra such that this only runs under GPU. - "no_oss", - "requires-gpu-nvidia", - ], - deps = [ - ":local_client", - ":nvidia_gpu_device", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:gpu_plugin", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/core:lib", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:random", - ], -) diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 31f51d70937..d37d480607a 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "include/dlpack/dlpack.h" // from @dlpack -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/stream_executor/cuda/cuda_platform_id.h" diff --git a/tensorflow/compiler/xla/python/dlpack.h b/tensorflow/compiler/xla/python/dlpack.h index 9d8965ac43d..6766bbe93b1 100644 --- a/tensorflow/compiler/xla/python/dlpack.h +++ b/tensorflow/compiler/xla/python/dlpack.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ #include "pybind11/pybind11.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index b5f1a831d4a..c460cc36f08 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -19,8 +19,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/python:local_client", - "//tensorflow/compiler/xla/python:semaphore", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:semaphore", "//tensorflow/compiler/xla/python/tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index fe2cddd75ef..e78f04ff980 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/python/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index f2c792d2a20..4c45df181db 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -24,7 +24,7 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h index 4ed4e9cb7f8..673f403d91e 100644 --- a/tensorflow/compiler/xla/python/types.h +++ b/tensorflow/compiler/xla/python/types.h @@ -26,7 +26,7 @@ limitations under the License. #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 206c304abbb..8bd771436c5 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -39,14 +39,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/pjrt/cpu_device.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" -#include "tensorflow/compiler/xla/python/cpu_device.h" -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/distributed/distributed.h" -#include "tensorflow/compiler/xla/python/distributed/service.h" #include "tensorflow/compiler/xla/python/dlpack.h" -#include "tensorflow/compiler/xla/python/local_client.h" -#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"