Split non-Python PJRT classes into their own directory.

PiperOrigin-RevId: 309424461
Change-Id: I471ee7ae98bc3be7e0540859ac111cce4ba5d6b5
A. Unique TensorFlower authored 2020-05-01 09:56:38 -07:00; committed by TensorFlower Gardener
parent 1d8770c6e2
commit 5a6996954e
42 changed files with 322 additions and 317 deletions

View File

@@ -0,0 +1,212 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "worker_thread",
srcs = ["worker_thread.cc"],
hdrs = ["worker_thread.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "event_pool",
srcs = ["event_pool.cc"],
hdrs = ["event_pool.h"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "semaphore",
srcs = ["semaphore.cc"],
hdrs = ["semaphore.h"],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "semaphore_test",
srcs = ["semaphore_test.cc"],
deps = [
":semaphore",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "tracked_device_buffer",
srcs = ["tracked_device_buffer.cc"],
hdrs = ["tracked_device_buffer.h"],
deps = [
":event_pool",
":local_device_state",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor:event",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "tracked_device_buffer_test",
srcs = ["tracked_device_buffer_test.cc"],
deps = [
":tracked_device_buffer",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:device_memory_allocator",
],
)
cc_library(
name = "local_device_state",
srcs = ["local_device_state.cc"],
hdrs = ["local_device_state.h"],
deps = [
":event_pool",
":semaphore",
":worker_thread",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor",
"//tensorflow/stream_executor:event",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "pjrt_client",
srcs = ["pjrt_client.cc"],
hdrs = ["pjrt_client.h"],
visibility = ["//tensorflow/compiler/xla:friends"],
deps = [
":event_pool",
":local_device_state",
":tracked_device_buffer",
"//tensorflow/compiler/xla:cpu_function_runtime",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
"//tensorflow/core:allocator",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/host:host_platform_id",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "cpu_device",
srcs = ["cpu_device.cc"],
hdrs = ["cpu_device.h"],
deps = [
":pjrt_client",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:platform_util",
],
)
cc_library(
name = "nvidia_gpu_device",
srcs = ["nvidia_gpu_device.cc"],
hdrs = ["nvidia_gpu_device.h"],
copts = if_cuda(["-DNCCL_ENABLED=1"]),
deps = [
":pjrt_client",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/pjrt/distributed:client",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/core/common_runtime:bfc_allocator",
"//tensorflow/core/common_runtime/gpu:gpu_mem_allocator",
"//tensorflow/stream_executor:tf_allocator_adapter",
] + if_cuda(["@local_config_nccl//:nccl"]),
)
tf_cc_test(
name = "gpu_multistream_test",
srcs = ["gpu_multistream_test.cc"],
tags = [
# TODO(phawkins): figure out TF test infra such that this only runs under GPU.
"no_oss",
"requires-gpu-nvidia",
],
deps = [
":nvidia_gpu_device",
":pjrt_client",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:random",
],
)

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/platform_util.h"

View File

@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_
#include <memory>
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
@@ -32,4 +32,4 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_
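For downstream code, the practical effect of this commit is an include-path and Bazel-dep rename; the GetCpuClient entry point itself is unchanged, as the hunk above shows. Below is a minimal sketch of a caller updated to the new paths, assuming a dep on //tensorflow/compiler/xla/pjrt:cpu_device; MakeCpuClientForTest is a hypothetical helper, not part of this commit.

// Hypothetical caller, updated from the old python/ includes to pjrt/.
#include <memory>

#include "tensorflow/compiler/xla/pjrt/cpu_device.h"   // was .../python/cpu_device.h
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"  // was .../python/local_client.h
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// GetCpuClient(bool asynchronous) is declared in cpu_device.h, unchanged
// by the move.
StatusOr<std::shared_ptr<PjRtClient>> MakeCpuClientForTest() {
  return GetCpuClient(/*asynchronous=*/true);
}

}  // namespace xla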

View File

@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include <chrono> // NOLINT
#include "tensorflow/compiler/xla/python/distributed/protocol.h"
#include "tensorflow/compiler/xla/python/distributed/util.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
#include "tensorflow/compiler/xla/pjrt/distributed/util.h"
namespace xla {

View File

@@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
#include <memory>
#include "grpcpp/channel.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/env.h"
@@ -47,4 +47,4 @@ class DistributedRuntimeClient {
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_

View File

@@ -15,10 +15,10 @@ limitations under the License.
#include "grpcpp/security/server_credentials.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/python/distributed/client.h"
#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h"
#include "tensorflow/compiler/xla/python/distributed/service.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/distributed/distributed.h"
#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
#include "grpcpp/grpcpp.h"

View File

@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_
#include <memory>
#include <string>
#include "tensorflow/compiler/xla/python/distributed/client.h"
#include "tensorflow/compiler/xla/python/distributed/service.h"
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
@@ -43,4 +43,4 @@ std::shared_ptr<DistributedRuntimeClient> GetDistributedRuntimeClient(
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/distributed/key_value_store.h"
#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h"
namespace xla {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_
#include "grpcpp/grpcpp.h"
#include "absl/base/thread_annotations.h"
@@ -50,4 +50,4 @@ class KeyValueStore {
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_
namespace xla {
@@ -22,4 +22,4 @@ static constexpr int kDistributedRuntimeProtocolVersion = 1;
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_

View File

@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/distributed/service.h"
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include "tensorflow/compiler/xla/python/distributed/protocol.h"
#include "tensorflow/compiler/xla/python/distributed/util.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
#include "tensorflow/compiler/xla/pjrt/distributed/util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/util.h"

View File

@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/python/distributed/key_value_store.h"
#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h"
#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
@@ -98,4 +98,4 @@ void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_

View File

@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/distributed/service.h"
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_
#include "grpcpp/support/status.h"
#include "tensorflow/compiler/xla/status.h"
@@ -41,4 +41,4 @@ inline ::grpc::Status ToGrpcStatus(const Status& s) {
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_
#include <memory>
#include <stack>
@@ -87,4 +87,4 @@ class EventPool {
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_
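This hunk shows only the guard rename for event_pool.h, but the header's <stack> include and the BUILD rule's deps (absl::Mutex plus stream-executor events) suggest a mutex-guarded free list of reusable events. The sketch below illustrates that pattern under those assumptions; SimpleEventPool and its Allocate/Release methods are hypothetical names, not the actual EventPool API.

#include <memory>
#include <stack>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace xla {
namespace se = ::stream_executor;

// Recycles se::Event objects so callers do not allocate a fresh event for
// every enqueued operation.
class SimpleEventPool {
 public:
  explicit SimpleEventPool(se::StreamExecutor* executor)
      : executor_(executor) {}

  // Pops a cached event if one is available, otherwise creates and
  // initializes a new one.
  std::unique_ptr<se::Event> Allocate() {
    {
      absl::MutexLock lock(&mu_);
      if (!free_events_.empty()) {
        std::unique_ptr<se::Event> event = std::move(free_events_.top());
        free_events_.pop();
        return event;
      }
    }
    auto event = std::make_unique<se::Event>(executor_);
    event->Init();
    return event;
  }

  // Returns an event to the pool for later reuse.
  void Release(std::unique_ptr<se::Event> event) {
    absl::MutexLock lock(&mu_);
    free_events_.push(std::move(event));
  }

 private:
  se::StreamExecutor* const executor_;
  absl::Mutex mu_;
  std::stack<std::unique_ptr<se::Event>> free_events_ ABSL_GUARDED_BY(mu_);
};

}  // namespace xla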

View File

@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h"
#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/platform/random.h"

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include <memory>
#include <vector>

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_
#include <memory>
#include <random>
@@ -22,9 +22,9 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/python/event_pool.h"
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "tensorflow/compiler/xla/python/worker_thread.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/platform/stream_executor.h"
@@ -207,4 +207,4 @@ class LocalDeviceState {
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h"
#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
#ifdef NCCL_ENABLED
#include "third_party/nccl/nccl.h"

View File

@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_
#include <memory>
#include "tensorflow/compiler/xla/python/distributed/client.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
@@ -59,4 +59,4 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_

View File

@@ -62,7 +62,7 @@ limitations under the License.
// See the comment on LocalDeviceState::AllocationModel for a discussion of the
// different allocation semantics on CPU, GPU, and TPU.
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include <cstddef>
#include <memory>
@@ -83,10 +83,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h"
#include "tensorflow/compiler/xla/python/event_pool.h"
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
#include <memory>
#include <string>
@@ -29,8 +29,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -681,4 +681,4 @@ class PjRtExecutable {
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/core/platform/logging.h"

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/types.h"
@@ -65,4 +65,4 @@ class Semaphore {
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_
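As with the other moved headers, only semaphore.h's guard changes here, but the target's deps (absl::Mutex and core_headers for thread annotations) point at the usual absl::Mutex::Await pattern for a counting semaphore. Below is a minimal sketch of that pattern, assuming an Acquire/Release interface; SimpleSemaphore is an illustrative name, not the class defined in this file.

#include <cstdint>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"

namespace xla {

// Counting semaphore built on absl::Mutex::Await.
class SimpleSemaphore {
 public:
  explicit SimpleSemaphore(int64_t capacity) : value_(capacity) {}

  // Blocks until `amount` units are available, then consumes them.
  void Acquire(int64_t amount) {
    struct Args {
      SimpleSemaphore* sem;
      int64_t amount;
    } args{this, amount};
    absl::MutexLock lock(&mu_);
    // Await releases mu_ while blocked and reacquires it once the
    // predicate holds.
    mu_.Await(absl::Condition(
        +[](Args* a) { return a->sem->value_ >= a->amount; }, &args));
    value_ -= amount;
  }

  // Returns `amount` units, waking any blocked acquirers.
  void Release(int64_t amount) {
    absl::MutexLock lock(&mu_);
    value_ += amount;
  }

 private:
  absl::Mutex mu_;
  int64_t value_ ABSL_GUARDED_BY(mu_);
};

}  // namespace xla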

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/compiler/xla/test.h"

View File

@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include <iterator>
#include <memory>
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/stream_executor/device_memory.h"

View File

@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/python/event_pool.h"
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
@@ -257,4 +257,4 @@ void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer,
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include <memory>

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/worker_thread.h"
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
namespace xla {

View File

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_
#include <functional>
#include <memory>
@@ -51,4 +51,4 @@ class WorkerThread {
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_
#endif // TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_
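worker_thread.h likewise changes only its guard. Given the declared deps (//tensorflow/core:lib for Env threads plus absl/synchronization) and the <functional> include, a plausible sketch is a single background thread draining a FIFO queue of closures; SimpleWorkerThread and its Schedule method are illustrative, not the actual API.

#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/platform/env.h"

namespace xla {

// Runs scheduled closures one at a time on a dedicated thread.
class SimpleWorkerThread {
 public:
  SimpleWorkerThread(tensorflow::Env* env, const std::string& name) {
    thread_.reset(env->StartThread(tensorflow::ThreadOptions(), name,
                                   [this]() { WorkLoop(); }));
  }

  ~SimpleWorkerThread() {
    Schedule(nullptr);  // nullptr acts as the shutdown sentinel.
    thread_.reset();    // Deleting a tensorflow::Thread joins it.
  }

  // Enqueues `fn` to run after all previously scheduled closures.
  void Schedule(std::function<void()> fn) {
    absl::MutexLock lock(&mu_);
    work_.push(std::move(fn));
  }

 private:
  void WorkLoop() {
    while (true) {
      std::function<void()> fn;
      {
        absl::MutexLock lock(&mu_);
        mu_.Await(absl::Condition(
            +[](std::queue<std::function<void()>>* q) { return !q->empty(); },
            &work_));
        fn = std::move(work_.front());
        work_.pop();
      }
      if (!fn) return;  // Shutdown sentinel reached.
      fn();
    }
  }

  absl::Mutex mu_;
  std::queue<std::function<void()>> work_ ABSL_GUARDED_BY(mu_);
  std::unique_ptr<tensorflow::Thread> thread_;
};

}  // namespace xla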

View File

@@ -1,7 +1,5 @@
load("//tensorflow/core/platform:build_config.bzl", "pyx_library")
load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps")
load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "pybind_extension")
@@ -78,16 +76,6 @@ py_test(
] + xla_py_test_deps(),
)
cc_library(
name = "worker_thread",
srcs = ["worker_thread.cc"],
hdrs = ["worker_thread.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "types",
srcs = ["types.cc"],
@@ -99,7 +87,6 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":bfloat16",
":local_client",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -107,6 +94,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/core:lib",
"//third_party/py/numpy:headers",
"@com_google_absl//absl/container:flat_hash_map",
@@ -116,148 +104,6 @@ cc_library(
],
)
cc_library(
name = "event_pool",
srcs = ["event_pool.cc"],
hdrs = ["event_pool.h"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "semaphore",
srcs = ["semaphore.cc"],
hdrs = ["semaphore.h"],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "semaphore_test",
srcs = ["semaphore_test.cc"],
deps = [
":semaphore",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "tracked_device_buffer",
srcs = ["tracked_device_buffer.cc"],
hdrs = ["tracked_device_buffer.h"],
deps = [
":event_pool",
":local_device_state",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor:event",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "tracked_device_buffer_test",
srcs = ["tracked_device_buffer_test.cc"],
deps = [
":tracked_device_buffer",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:device_memory_allocator",
],
)
cc_library(
name = "local_device_state",
srcs = ["local_device_state.cc"],
hdrs = ["local_device_state.h"],
deps = [
":event_pool",
":semaphore",
":worker_thread",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor",
"//tensorflow/stream_executor:event",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "local_client",
srcs = ["local_client.cc"],
hdrs = ["local_client.h"],
visibility = ["//tensorflow/compiler/xla:friends"],
deps = [
":event_pool",
":local_device_state",
":tracked_device_buffer",
"//tensorflow/compiler/xla:cpu_function_runtime",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/python/distributed:protocol_proto_cc",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
"//tensorflow/core:allocator",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/host:host_platform_id",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "python_ref_manager",
srcs = ["python_ref_manager.cc"],
@@ -322,10 +168,10 @@ cc_library(
],
features = ["-use_header_modules"],
deps = [
":local_client",
":tracked_device_buffer",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:platform",
"//tensorflow/stream_executor/cuda:cuda_platform_id",
@@ -340,37 +186,6 @@ cc_library(
],
)
cc_library(
name = "cpu_device",
srcs = ["cpu_device.cc"],
hdrs = ["cpu_device.h"],
deps = [
":local_client",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:platform_util",
],
)
cc_library(
name = "nvidia_gpu_device",
srcs = ["nvidia_gpu_device.cc"],
hdrs = ["nvidia_gpu_device.h"],
copts = if_cuda(["-DNCCL_ENABLED=1"]),
deps = [
":local_client",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/python/distributed:client",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/core/common_runtime:bfc_allocator",
"//tensorflow/core/common_runtime/gpu:gpu_mem_allocator",
"//tensorflow/stream_executor:tf_allocator_adapter",
] + if_cuda(["@local_config_nccl//:nccl"]),
)
config_setting(
name = "enable_gpu",
values = {"define": "xla_python_enable_gpu=true"},
@@ -389,11 +204,7 @@ pybind_extension(
module_name = "xla_extension",
deps = [
":bfloat16",
":cpu_device",
":dlpack",
":local_client",
":nvidia_gpu_device",
":tracked_device_buffer",
":python_ref_manager",
":types",
"@com_google_absl//absl/base",
@@ -423,9 +234,13 @@ pybind_extension(
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/compiler/xla/client/lib:svd",
"//tensorflow/compiler/xla/python/distributed",
"//tensorflow/compiler/xla/python/distributed:client",
"//tensorflow/compiler/xla/python/distributed:service",
"//tensorflow/compiler/xla/pjrt:cpu_device",
"//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
"//tensorflow/compiler/xla/pjrt/distributed",
"//tensorflow/compiler/xla/pjrt/distributed:client",
"//tensorflow/compiler/xla/pjrt/distributed:service",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/compiler/xla/service:hlo",
@@ -454,25 +269,3 @@ pybind_extension(
"//conditions:default": [],
}),
)
tf_cc_test(
name = "gpu_multistream_test",
srcs = ["gpu_multistream_test.cc"],
tags = [
# TODO(phawkins): figure out TF test infra such that this only runs under GPU.
"no_oss",
"requires-gpu-nvidia",
],
deps = [
":local_client",
":nvidia_gpu_device",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:random",
],
)

View File

@@ -23,7 +23,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"

View File

@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
namespace xla {

View File

@@ -19,8 +19,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/python:local_client",
"//tensorflow/compiler/xla/python:semaphore",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:semaphore",
"//tensorflow/compiler/xla/python/tpu_driver",
"//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
"//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver",

View File

@@ -24,7 +24,7 @@ limitations under the License.
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/shape_util.h"

View File

@@ -24,7 +24,7 @@ limitations under the License.
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"

View File

@@ -26,7 +26,7 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"

View File

@@ -39,14 +39,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/python/cpu_device.h"
#include "tensorflow/compiler/xla/python/distributed/client.h"
#include "tensorflow/compiler/xla/python/distributed/distributed.h"
#include "tensorflow/compiler/xla/python/distributed/service.h"
#include "tensorflow/compiler/xla/python/dlpack.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"