[XLA:Python] Drop and/or minimize some StreamExecutor dependencies from the XLA:Python TPU driver code.
In principle, the RPC-based TPU driver does not depend on StreamExecutor; any such dependency is spurious. This change prunes most of the StreamExecutor dependencies, leaving only the one in xla/service:computation_placer, whose dependency is instead narrowed to a more specific build target. PiperOrigin-RevId: 352799982 Change-Id: Ieb54fb96e5843ac336abe16707a7c57def7edb3a
This commit is contained in:
parent
43f5aec62f
commit
ceaeaac6dd
@ -101,7 +101,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla:xla_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
|
@ -23,7 +23,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
|
||||
namespace stream_executor {
|
||||
|
||||
// Forward-declared to avoid StreamExecutor dependency.
|
||||
class DeviceMemoryAllocator;
|
||||
|
||||
} // namespace stream_executor
|
||||
|
||||
namespace xla {
|
||||
|
||||
|
@ -36,7 +36,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/python/tpu_driver:recording_tpu_driver",
|
||||
"//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/core/framework:allocator",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:env",
|
||||
|
@ -28,7 +28,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
|
||||
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
@ -2883,6 +2883,10 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:platform",
|
||||
"//tensorflow/stream_executor/cuda:cuda_platform_id",
|
||||
"//tensorflow/stream_executor/host:host_platform_id",
|
||||
"//tensorflow/stream_executor/rocm:rocm_platform_id",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
@ -2980,7 +2984,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:stream_executor_no_cuda",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
|
@ -33,7 +33,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
|
||||
#include "tensorflow/stream_executor/host/host_platform_id.h"
|
||||
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
||||
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
|
@ -28,8 +28,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
|
@ -273,7 +273,8 @@ Status ExecuteWrapperAfterExecution(
|
||||
&stream->parent()->GetDeviceDescription();
|
||||
std::shared_ptr<HloExecutionProfile> profile = state.profile_ptr;
|
||||
stream->ThenDoHostCallback([profile, device_description]() {
|
||||
XLA_LOG_LINES(tensorflow::INFO, profile->ToString(*device_description));
|
||||
XLA_LOG_LINES(tensorflow::INFO,
|
||||
profile->ToString(device_description->clock_rate_ghz()));
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_execution_profile_data.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
@ -106,8 +105,6 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
|
||||
// down how much time each HLO took.
|
||||
class HloExecutionProfile {
|
||||
public:
|
||||
using DeviceDescription = se::DeviceDescription;
|
||||
|
||||
HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data,
|
||||
const HloProfileIndexMap* hlo_profile_index_map);
|
||||
|
||||
@ -149,9 +146,9 @@ class HloExecutionProfile {
|
||||
// frequency, and the effective throughput given the provided cost_analysis
|
||||
// for the operations in a given computation. Returns an empty string if it
|
||||
// wasn't possible to generate a printable version.
|
||||
string ToString(const DeviceDescription& device_description) const {
|
||||
string ToString(float clock_rate_ghz) const {
|
||||
return PrintHloProfile(hlo_profile_printer_data_, profile_counters_.data(),
|
||||
device_description.clock_rate_ghz());
|
||||
clock_rate_ghz);
|
||||
}
|
||||
|
||||
std::vector<int64>* mutable_profile_counters() { return &profile_counters_; }
|
||||
|
@ -64,8 +64,11 @@ TEST_F(HloExecutionProfileTest, Basic) {
|
||||
execution_profile.SetCyclesTakenBy(add_instruction, add_cycles);
|
||||
execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles);
|
||||
|
||||
EXPECT_THAT(execution_profile.ToString(
|
||||
backend().default_stream_executor()->GetDeviceDescription()),
|
||||
float clock_rate_ghz = backend()
|
||||
.default_stream_executor()
|
||||
->GetDeviceDescription()
|
||||
.clock_rate_ghz();
|
||||
EXPECT_THAT(execution_profile.ToString(clock_rate_ghz),
|
||||
AllOf(ContainsRegex(StrCat(dot_cycles, " cycles.*%",
|
||||
dot_instruction->name())),
|
||||
ContainsRegex(StrCat(add_cycles, " cycles.*%",
|
||||
|
@ -181,8 +181,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
|
||||
TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
|
||||
(void)execution_result;
|
||||
|
||||
*profile_output =
|
||||
hlo_execution_profile.ToString(executor->GetDeviceDescription());
|
||||
*profile_output = hlo_execution_profile.ToString(
|
||||
executor->GetDeviceDescription().clock_rate_ghz());
|
||||
|
||||
XLA_VLOG_LINES(4, *profile_output);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user