[XLA:Python] Drop and/or minimize some streamexecutor dependencies from the XLA:Python TPU driver code.

In principle, the RPC-based TPU driver does not depend on StreamExecutor; any such dependency is spurious. This change prunes most of the StreamExecutor dependencies, leaving only the one in xla/service:computation_placer, which is instead pruned down to a more specific build target.

PiperOrigin-RevId: 352799982
Change-Id: Ieb54fb96e5843ac336abe16707a7c57def7edb3a
This commit is contained in:
Peter Hawkins 2021-01-20 08:38:07 -08:00 committed by TensorFlower Gardener
parent 43f5aec62f
commit ceaeaac6dd
11 changed files with 29 additions and 17 deletions

View File

@@ -101,7 +101,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",

View File

@@ -23,7 +23,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/core/platform/threadpool.h"
namespace stream_executor {
// Forward-declared to avoid StreamExecutor dependency.
class DeviceMemoryAllocator;
} // namespace stream_executor
namespace xla {

View File

@@ -36,7 +36,6 @@ cc_library(
"//tensorflow/compiler/xla/python/tpu_driver:recording_tpu_driver",
"//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core/framework:allocator",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:env",

View File

@@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"

View File

@@ -2883,6 +2883,10 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core/platform:stream_executor_no_cuda",
"//tensorflow/stream_executor:platform",
"//tensorflow/stream_executor/cuda:cuda_platform_id",
"//tensorflow/stream_executor/host:host_platform_id",
"//tensorflow/stream_executor/rocm:rocm_platform_id",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -2980,7 +2984,6 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core/platform:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
],

View File

@@ -33,7 +33,9 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
using absl::StrAppend;
using absl::StrCat;

View File

@@ -28,8 +28,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {

View File

@@ -273,7 +273,8 @@ Status ExecuteWrapperAfterExecution(
&stream->parent()->GetDeviceDescription();
std::shared_ptr<HloExecutionProfile> profile = state.profile_ptr;
stream->ThenDoHostCallback([profile, device_description]() {
XLA_LOG_LINES(tensorflow::INFO, profile->ToString(*device_description));
XLA_LOG_LINES(tensorflow::INFO,
profile->ToString(device_description->clock_rate_ghz()));
});
}

View File

@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile_data.pb.h"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -106,8 +105,6 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
// down how much time each HLO took.
class HloExecutionProfile {
public:
using DeviceDescription = se::DeviceDescription;
HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data,
const HloProfileIndexMap* hlo_profile_index_map);
@@ -149,9 +146,9 @@ class HloExecutionProfile {
// frequency, and the effective throughput given the provided cost_analysis
// for the operations in a given computation. Returns an empty string if it
// wasn't possible to generate a printable version.
string ToString(const DeviceDescription& device_description) const {
string ToString(float clock_rate_ghz) const {
return PrintHloProfile(hlo_profile_printer_data_, profile_counters_.data(),
device_description.clock_rate_ghz());
clock_rate_ghz);
}
std::vector<int64>* mutable_profile_counters() { return &profile_counters_; }

View File

@@ -64,8 +64,11 @@ TEST_F(HloExecutionProfileTest, Basic) {
execution_profile.SetCyclesTakenBy(add_instruction, add_cycles);
execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles);
EXPECT_THAT(execution_profile.ToString(
backend().default_stream_executor()->GetDeviceDescription()),
float clock_rate_ghz = backend()
.default_stream_executor()
->GetDeviceDescription()
.clock_rate_ghz();
EXPECT_THAT(execution_profile.ToString(clock_rate_ghz),
AllOf(ContainsRegex(StrCat(dot_cycles, " cycles.*%",
dot_instruction->name())),
ContainsRegex(StrCat(add_cycles, " cycles.*%",

View File

@@ -181,8 +181,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
(void)execution_result;
*profile_output =
hlo_execution_profile.ToString(executor->GetDeviceDescription());
*profile_output = hlo_execution_profile.ToString(
executor->GetDeviceDescription().clock_rate_ghz());
XLA_VLOG_LINES(4, *profile_output);
}