From ceaeaac6ddac3b4981f1434461ac48554451df2a Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 20 Jan 2021 08:38:07 -0800
Subject: [PATCH] [XLA:Python] Drop or minimize some StreamExecutor dependencies from the XLA:Python TPU driver code.

In principle, the RPC-based TPU driver does not depend on StreamExecutor; any
such dependency is spurious. This change prunes most of the StreamExecutor
dependencies, leaving only the one in xla/service:computation_placer, which is
in turn narrowed to more specific build targets.

PiperOrigin-RevId: 352799982
Change-Id: Ieb54fb96e5843ac336abe16707a7c57def7edb3a
---
 tensorflow/compiler/xla/client/BUILD | 2 +-
 .../compiler/xla/client/executable_build_options.h | 9 ++++++++-
 tensorflow/compiler/xla/python/tpu_driver/client/BUILD | 1 -
 .../compiler/xla/python/tpu_driver/client/tpu_client.h | 1 -
 tensorflow/compiler/xla/service/BUILD | 5 ++++-
 tensorflow/compiler/xla/service/computation_placer.cc | 4 +++-
 tensorflow/compiler/xla/service/computation_placer.h | 3 ++-
 tensorflow/compiler/xla/service/executable.cc | 3 ++-
 tensorflow/compiler/xla/service/hlo_execution_profile.h | 7 ++-----
 .../compiler/xla/service/hlo_execution_profile_test.cc | 7 +++++--
 tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc | 4 ++--
 11 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 6baeca85149..98128c75723 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -101,7 +101,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla:xla_proto_cc",
         "//tensorflow/compiler/xla/service:computation_placer",
-        "//tensorflow/stream_executor:device_memory_allocator",
+        "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 21f8e1fdb05..280860a9834 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -23,7 +23,14 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla.pb.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/stream_executor/device_memory_allocator.h"
+#include "tensorflow/core/platform/threadpool.h"
+
+namespace stream_executor {
+
+// Forward-declared to avoid a StreamExecutor dependency.
+class DeviceMemoryAllocator; + +} // namespace stream_executor namespace xla { diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index 25c39d68960..9020623be89 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -36,7 +36,6 @@ cc_library( "//tensorflow/compiler/xla/python/tpu_driver:recording_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc", "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core/framework:allocator", "//tensorflow/core/platform:casts", "//tensorflow/core/platform:env", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 85211b01bbb..c70d98b60e6 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 86e2b347329..ea828c54553 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2883,6 +2883,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core/platform:stream_executor_no_cuda", + "//tensorflow/stream_executor:platform", + "//tensorflow/stream_executor/cuda:cuda_platform_id", + "//tensorflow/stream_executor/host:host_platform_id", + "//tensorflow/stream_executor/rocm:rocm_platform_id", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -2980,7 +2984,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index f9aaa1a676e..bc206da7218 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -33,7 +33,9 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_platform_id.h" +#include "tensorflow/stream_executor/host/host_platform_id.h" +#include "tensorflow/stream_executor/rocm/rocm_platform_id.h" using absl::StrAppend; using absl::StrCat; diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index e6b591b7e23..217aa7ad002 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -28,8 +28,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/platform.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 42794fc995b..fd2f1dd54d2 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -273,7 +273,8 @@ Status ExecuteWrapperAfterExecution( &stream->parent()->GetDeviceDescription(); std::shared_ptr profile = state.profile_ptr; stream->ThenDoHostCallback([profile, device_description]() { - XLA_LOG_LINES(tensorflow::INFO, profile->ToString(*device_description)); + XLA_LOG_LINES(tensorflow::INFO, + profile->ToString(device_description->clock_rate_ghz())); }); } diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index 3538dce6f13..e000cca9ca7 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile_data.pb.h" #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -106,8 +105,6 @@ std::unique_ptr CreateHloProfilePrinterData( // down how much time each HLO took. class HloExecutionProfile { public: - using DeviceDescription = se::DeviceDescription; - HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data, const HloProfileIndexMap* hlo_profile_index_map); @@ -149,9 +146,9 @@ class HloExecutionProfile { // frequency, and the effective throughput given the provided cost_analysis // for the operations in a given computation. Returns an empty string if it // wasn't possible to generate a printable version. 
-  string ToString(const DeviceDescription& device_description) const {
+  string ToString(float clock_rate_ghz) const {
     return PrintHloProfile(hlo_profile_printer_data_, profile_counters_.data(),
-                           device_description.clock_rate_ghz());
+                           clock_rate_ghz);
   }
 
   std::vector<int64>* mutable_profile_counters() { return &profile_counters_; }
 
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
index 57fc5ec0748..264d6aaa7f7 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
@@ -64,8 +64,11 @@ TEST_F(HloExecutionProfileTest, Basic) {
   execution_profile.SetCyclesTakenBy(add_instruction, add_cycles);
   execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles);
 
-  EXPECT_THAT(execution_profile.ToString(
-                  backend().default_stream_executor()->GetDeviceDescription()),
+  float clock_rate_ghz = backend()
+                             .default_stream_executor()
+                             ->GetDeviceDescription()
+                             .clock_rate_ghz();
+  EXPECT_THAT(execution_profile.ToString(clock_rate_ghz),
               AllOf(ContainsRegex(StrCat(dot_cycles, " cycles.*%",
                                          dot_instruction->name())),
                     ContainsRegex(StrCat(add_cycles, " cycles.*%",
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 1b8203e02a9..79ae09d2dc6 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -181,8 +181,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
   TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
   (void)execution_result;
 
-  *profile_output =
-      hlo_execution_profile.ToString(executor->GetDeviceDescription());
+  *profile_output = hlo_execution_profile.ToString(
+      executor->GetDeviceDescription().clock_rate_ghz());
 
   XLA_VLOG_LINES(4, *profile_output);
 }
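
For reference, a minimal migration sketch for callers outside this patch that
previously passed a DeviceDescription to HloExecutionProfile::ToString(). The
helper name ProfileToString and the way `executor` is obtained are
illustrative assumptions, not part of the change; only the
ToString(float clock_rate_ghz) signature comes from the patch itself:

// Migration sketch (not part of the patch). Assumes the caller already has a
// valid se::StreamExecutor* from elsewhere.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/stream_executor/stream_executor.h"

std::string ProfileToString(const xla::HloExecutionProfile& profile,
                            stream_executor::StreamExecutor* executor) {
  // ToString() now takes only the clock rate, so callers that do have a
  // StreamExecutor extract it themselves instead of handing over the whole
  // DeviceDescription; headers that only need profiling no longer pull in
  // StreamExecutor.
  const float clock_rate_ghz =
      executor->GetDeviceDescription().clock_rate_ghz();
  return profile.ToString(clock_rate_ghz);
}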