Split non-Python PJRT classes into their own directory.
PiperOrigin-RevId: 309424461 Change-Id: I471ee7ae98bc3be7e0540859ac111cce4ba5d6b5
This commit is contained in:
parent
1d8770c6e2
commit
5a6996954e
212
tensorflow/compiler/xla/pjrt/BUILD
Normal file
212
tensorflow/compiler/xla/pjrt/BUILD
Normal file
@ -0,0 +1,212 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "worker_thread",
|
||||
srcs = ["worker_thread.cc"],
|
||||
hdrs = ["worker_thread.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "event_pool",
|
||||
srcs = ["event_pool.cc"],
|
||||
hdrs = ["event_pool.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "semaphore",
|
||||
srcs = ["semaphore.cc"],
|
||||
hdrs = ["semaphore.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "semaphore_test",
|
||||
srcs = ["semaphore_test.cc"],
|
||||
deps = [
|
||||
":semaphore",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tracked_device_buffer",
|
||||
srcs = ["tracked_device_buffer.cc"],
|
||||
hdrs = ["tracked_device_buffer.h"],
|
||||
deps = [
|
||||
":event_pool",
|
||||
":local_device_state",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor:device_memory",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tracked_device_buffer_test",
|
||||
srcs = ["tracked_device_buffer_test.cc"],
|
||||
deps = [
|
||||
":tracked_device_buffer",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/stream_executor:device_memory",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "local_device_state",
|
||||
srcs = ["local_device_state.cc"],
|
||||
hdrs = ["local_device_state.h"],
|
||||
deps = [
|
||||
":event_pool",
|
||||
":semaphore",
|
||||
":worker_thread",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pjrt_client",
|
||||
srcs = ["pjrt_client.cc"],
|
||||
hdrs = ["pjrt_client.h"],
|
||||
visibility = ["//tensorflow/compiler/xla:friends"],
|
||||
deps = [
|
||||
":event_pool",
|
||||
":local_device_state",
|
||||
":tracked_device_buffer",
|
||||
"//tensorflow/compiler/xla:cpu_function_runtime",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
||||
"//tensorflow/core:allocator",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"//tensorflow/stream_executor:stream",
|
||||
"//tensorflow/stream_executor/host:host_platform_id",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_device",
|
||||
srcs = ["cpu_device.cc"],
|
||||
hdrs = ["cpu_device.h"],
|
||||
deps = [
|
||||
":pjrt_client",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nvidia_gpu_device",
|
||||
srcs = ["nvidia_gpu_device.cc"],
|
||||
hdrs = ["nvidia_gpu_device.h"],
|
||||
copts = if_cuda(["-DNCCL_ENABLED=1"]),
|
||||
deps = [
|
||||
":pjrt_client",
|
||||
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/pjrt/distributed:client",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core/common_runtime:bfc_allocator",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_mem_allocator",
|
||||
"//tensorflow/stream_executor:tf_allocator_adapter",
|
||||
] + if_cuda(["@local_config_nccl//:nccl"]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gpu_multistream_test",
|
||||
srcs = ["gpu_multistream_test.cc"],
|
||||
tags = [
|
||||
# TODO(phawkins): figure out TF test infra such that this only runs under GPU.
|
||||
"no_oss",
|
||||
"requires-gpu-nvidia",
|
||||
],
|
||||
deps = [
|
||||
":nvidia_gpu_device",
|
||||
":pjrt_client",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:random",
|
||||
],
|
||||
)
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/cpu_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
@ -32,4 +32,4 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_
|
@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/protocol.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/util.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/util.h"
|
||||
|
||||
namespace xla {
|
||||
|
@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "grpcpp/channel.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
@ -47,4 +47,4 @@ class DistributedRuntimeClient {
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
|
@ -15,10 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "grpcpp/security/server_credentials.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
|
||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/service.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/distributed.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
|
||||
|
||||
#include "grpcpp/grpcpp.h"
|
||||
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/service.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
@ -43,4 +43,4 @@ std::shared_ptr<DistributedRuntimeClient> GetDistributedRuntimeClient(
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/key_value_store.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h"
|
||||
|
||||
namespace xla {
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_
|
||||
|
||||
#include "grpcpp/grpcpp.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
@ -50,4 +50,4 @@ class KeyValueStore {
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -22,4 +22,4 @@ static constexpr int kDistributedRuntimeProtocolVersion = 1;
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_
|
@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/service.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/protocol.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/util.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
|
||||
|
||||
#include "absl/time/time.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/key_value_store.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
@ -98,4 +98,4 @@ void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
|
@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/service.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_
|
||||
|
||||
#include "grpcpp/support/status.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
@ -41,4 +41,4 @@ inline ::grpc::Status ToGrpcStatus(const Status& s) {
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/event_pool.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/synchronization/mutex.h"
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
@ -87,4 +87,4 @@ class EventPool {
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/core/platform/random.h"
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <random>
|
||||
@ -22,9 +22,9 @@ limitations under the License.
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/python/event_pool.h"
|
||||
#include "tensorflow/compiler/xla/python/semaphore.h"
|
||||
#include "tensorflow/compiler/xla/python/worker_thread.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
|
||||
@ -207,4 +207,4 @@ class LocalDeviceState {
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
|
||||
|
||||
#ifdef NCCL_ENABLED
|
||||
#include "third_party/nccl/nccl.h"
|
@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/python/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/bfc_allocator.h"
|
||||
|
||||
@ -59,4 +59,4 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_
|
@ -62,7 +62,7 @@ limitations under the License.
|
||||
// See the comment on LocalDeviceState::AllocationModel for a discussion of the
|
||||
// different allocation semantics on CPU, GPU, and TPU.
|
||||
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
@ -83,10 +83,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h"
|
||||
#include "tensorflow/compiler/xla/python/event_pool.h"
|
||||
#include "tensorflow/compiler/xla/python/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -29,8 +29,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/python/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
@ -681,4 +681,4 @@ class PjRtExecutable {
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/semaphore.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -65,4 +65,4 @@ class Semaphore {
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/semaphore.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
|
||||
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/xla/python/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/stream_executor/device_memory.h"
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/python/event_pool.h"
|
||||
#include "tensorflow/compiler/xla/python/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
@ -257,4 +257,4 @@ void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer,
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||
|
||||
#include <memory>
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/worker_thread.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
|
||||
|
||||
namespace xla {
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
@ -51,4 +51,4 @@ class WorkerThread {
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_
|
@ -1,7 +1,5 @@
|
||||
load("//tensorflow/core/platform:build_config.bzl", "pyx_library")
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps")
|
||||
load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_test")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
@ -78,16 +76,6 @@ py_test(
|
||||
] + xla_py_test_deps(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "worker_thread",
|
||||
srcs = ["worker_thread.cc"],
|
||||
hdrs = ["worker_thread.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "types",
|
||||
srcs = ["types.cc"],
|
||||
@ -99,7 +87,6 @@ cc_library(
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":bfloat16",
|
||||
":local_client",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
@ -107,6 +94,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/core:lib",
|
||||
"//third_party/py/numpy:headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -116,148 +104,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "event_pool",
|
||||
srcs = ["event_pool.cc"],
|
||||
hdrs = ["event_pool.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "semaphore",
|
||||
srcs = ["semaphore.cc"],
|
||||
hdrs = ["semaphore.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "semaphore_test",
|
||||
srcs = ["semaphore_test.cc"],
|
||||
deps = [
|
||||
":semaphore",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tracked_device_buffer",
|
||||
srcs = ["tracked_device_buffer.cc"],
|
||||
hdrs = ["tracked_device_buffer.h"],
|
||||
deps = [
|
||||
":event_pool",
|
||||
":local_device_state",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor:device_memory",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tracked_device_buffer_test",
|
||||
srcs = ["tracked_device_buffer_test.cc"],
|
||||
deps = [
|
||||
":tracked_device_buffer",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/stream_executor:device_memory",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "local_device_state",
|
||||
srcs = ["local_device_state.cc"],
|
||||
hdrs = ["local_device_state.h"],
|
||||
deps = [
|
||||
":event_pool",
|
||||
":semaphore",
|
||||
":worker_thread",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "local_client",
|
||||
srcs = ["local_client.cc"],
|
||||
hdrs = ["local_client.h"],
|
||||
visibility = ["//tensorflow/compiler/xla:friends"],
|
||||
deps = [
|
||||
":event_pool",
|
||||
":local_device_state",
|
||||
":tracked_device_buffer",
|
||||
"//tensorflow/compiler/xla:cpu_function_runtime",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/python/distributed:protocol_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
||||
"//tensorflow/core:allocator",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"//tensorflow/stream_executor:stream",
|
||||
"//tensorflow/stream_executor/host:host_platform_id",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "python_ref_manager",
|
||||
srcs = ["python_ref_manager.cc"],
|
||||
@ -322,10 +168,10 @@ cc_library(
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":local_client",
|
||||
":tracked_device_buffer",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
|
||||
"//tensorflow/stream_executor:device_memory",
|
||||
"//tensorflow/stream_executor:platform",
|
||||
"//tensorflow/stream_executor/cuda:cuda_platform_id",
|
||||
@ -340,37 +186,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_device",
|
||||
srcs = ["cpu_device.cc"],
|
||||
hdrs = ["cpu_device.h"],
|
||||
deps = [
|
||||
":local_client",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nvidia_gpu_device",
|
||||
srcs = ["nvidia_gpu_device.cc"],
|
||||
hdrs = ["nvidia_gpu_device.h"],
|
||||
copts = if_cuda(["-DNCCL_ENABLED=1"]),
|
||||
deps = [
|
||||
":local_client",
|
||||
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/python/distributed:client",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core/common_runtime:bfc_allocator",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_mem_allocator",
|
||||
"//tensorflow/stream_executor:tf_allocator_adapter",
|
||||
] + if_cuda(["@local_config_nccl//:nccl"]),
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "enable_gpu",
|
||||
values = {"define": "xla_python_enable_gpu=true"},
|
||||
@ -389,11 +204,7 @@ pybind_extension(
|
||||
module_name = "xla_extension",
|
||||
deps = [
|
||||
":bfloat16",
|
||||
":cpu_device",
|
||||
":dlpack",
|
||||
":local_client",
|
||||
":nvidia_gpu_device",
|
||||
":tracked_device_buffer",
|
||||
":python_ref_manager",
|
||||
":types",
|
||||
"@com_google_absl//absl/base",
|
||||
@ -423,9 +234,13 @@ pybind_extension(
|
||||
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
|
||||
"//tensorflow/compiler/xla/client/lib:sorting",
|
||||
"//tensorflow/compiler/xla/client/lib:svd",
|
||||
"//tensorflow/compiler/xla/python/distributed",
|
||||
"//tensorflow/compiler/xla/python/distributed:client",
|
||||
"//tensorflow/compiler/xla/python/distributed:service",
|
||||
"//tensorflow/compiler/xla/pjrt:cpu_device",
|
||||
"//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
|
||||
"//tensorflow/compiler/xla/pjrt/distributed",
|
||||
"//tensorflow/compiler/xla/pjrt/distributed:client",
|
||||
"//tensorflow/compiler/xla/pjrt/distributed:service",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
@ -454,25 +269,3 @@ pybind_extension(
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gpu_multistream_test",
|
||||
srcs = ["gpu_multistream_test.cc"],
|
||||
tags = [
|
||||
# TODO(phawkins): figure out TF test infra such that this only runs under GPU.
|
||||
"no_oss",
|
||||
"requires-gpu-nvidia",
|
||||
],
|
||||
deps = [
|
||||
":local_client",
|
||||
":nvidia_gpu_device",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:random",
|
||||
],
|
||||
)
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "include/dlpack/dlpack.h" // from @dlpack
|
||||
#include "tensorflow/compiler/xla/python/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
|
@ -19,8 +19,8 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||
"//tensorflow/compiler/xla/python:local_client",
|
||||
"//tensorflow/compiler/xla/python:semaphore",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/compiler/xla/pjrt:semaphore",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver",
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/python/semaphore.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
|
||||
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
|
||||
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
@ -39,14 +39,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/python/bfloat16.h"
|
||||
#include "tensorflow/compiler/xla/python/cpu_device.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/distributed.h"
|
||||
#include "tensorflow/compiler/xla/python/distributed/service.h"
|
||||
#include "tensorflow/compiler/xla/python/dlpack.h"
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h"
|
||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
|
Loading…
x
Reference in New Issue
Block a user