XLA/CPU: register the HLO module so that it can be captured.
PiperOrigin-RevId: 356868235 Change-Id: Iaf5faf2aa7e8776005ae8da4f1b0d9753e0731ce
This commit is contained in:
parent
05756fcc81
commit
a5ff8a7043
@ -341,6 +341,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:logical_buffer",
|
"//tensorflow/compiler/xla/service:logical_buffer",
|
||||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||||
|
"//tensorflow/compiler/xla/service:xla_debug_info_manager",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
"//tensorflow/core/platform:macros",
|
"//tensorflow/core/platform:macros",
|
||||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
||||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
|
||||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
@ -60,7 +61,14 @@ CpuExecutable::CpuExecutable(
|
|||||||
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
|
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
|
||||||
std::move(hlo_profile_index_map)),
|
std::move(hlo_profile_index_map)),
|
||||||
jit_(std::move(jit)),
|
jit_(std::move(jit)),
|
||||||
assignment_(std::move(assignment)) {
|
assignment_(std::move(assignment)),
|
||||||
|
module_name_(entry_function_name) {
|
||||||
|
if (assignment_) {
|
||||||
|
buffer_assignment_.reset(new BufferAssignmentProto(assignment_->ToProto()));
|
||||||
|
}
|
||||||
|
XlaDebugInfoManager::Get()->RegisterModule(module_name_, shared_module(),
|
||||||
|
buffer_assignment_);
|
||||||
|
|
||||||
// Resolve symbols in the constructor rather than at execution time to avoid
|
// Resolve symbols in the constructor rather than at execution time to avoid
|
||||||
// races because FindSymbol is not thread safe.
|
// races because FindSymbol is not thread safe.
|
||||||
llvm::Expected<llvm::JITEvaluatedSymbol> sym =
|
llvm::Expected<llvm::JITEvaluatedSymbol> sym =
|
||||||
@ -75,6 +83,11 @@ CpuExecutable::CpuExecutable(
|
|||||||
<< reinterpret_cast<void*>(compute_function_);
|
<< reinterpret_cast<void*>(compute_function_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unregisters this executable from XlaDebugInfoManager, passing the same
// module name, shared HLO module, and buffer assignment proto that the
// constructor passed to RegisterModule.
CpuExecutable::~CpuExecutable() {
|
||||||
|
XlaDebugInfoManager::Get()->UnregisterModule(module_name_, shared_module(),
|
||||||
|
buffer_assignment_);
|
||||||
|
}
|
||||||
|
|
||||||
static StatusOr<MaybeOwningDeviceMemory> MemoryForAllocation(
|
static StatusOr<MaybeOwningDeviceMemory> MemoryForAllocation(
|
||||||
const BufferAllocation& allocation,
|
const BufferAllocation& allocation,
|
||||||
absl::Span<ExecutionInput const> arguments,
|
absl::Span<ExecutionInput const> arguments,
|
||||||
@ -151,6 +164,10 @@ Status CpuExecutable::ExecuteComputeFunction(
|
|||||||
|
|
||||||
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
|
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
|
||||||
|
|
||||||
|
XlaDebugInfoManager::Get()->OnModuleStart(module_name_);
|
||||||
|
auto cleanup = MakeCleanup(
|
||||||
|
[&]() { XlaDebugInfoManager::Get()->OnModuleStop(module_name_); });
|
||||||
|
|
||||||
size_t profile_counters_size =
|
size_t profile_counters_size =
|
||||||
hlo_execution_profile ? hlo_execution_profile->profile_counters().size()
|
hlo_execution_profile ? hlo_execution_profile->profile_counters().size()
|
||||||
: 0;
|
: 0;
|
||||||
|
@ -53,7 +53,7 @@ class CpuExecutable : public Executable {
|
|||||||
const string& entry_function_name,
|
const string& entry_function_name,
|
||||||
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
|
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
|
||||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
|
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
|
||||||
~CpuExecutable() override {}
|
~CpuExecutable() override;
|
||||||
|
|
||||||
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
@ -131,12 +131,17 @@ class CpuExecutable : public Executable {
|
|||||||
// Buffer assignment for the buffers we need to allocate.
|
// Buffer assignment for the buffers we need to allocate.
|
||||||
const std::unique_ptr<const BufferAssignment> assignment_;
|
const std::unique_ptr<const BufferAssignment> assignment_;
|
||||||
|
|
||||||
|
std::shared_ptr<const BufferAssignmentProto> buffer_assignment_;
|
||||||
|
|
||||||
// The LLVM IR, in string format, of the unoptimized module generated for this
|
// The LLVM IR, in string format, of the unoptimized module generated for this
|
||||||
// CpuExecutable. We save a string instead of an llvm::Module* because leaving
|
// CpuExecutable. We save a string instead of an llvm::Module* because leaving
|
||||||
// llvm::Module* in a singleton can cause the heap checker to emit false
|
// llvm::Module* in a singleton can cause the heap checker to emit false
|
||||||
// positives.
|
// positives.
|
||||||
string ir_module_string_;
|
string ir_module_string_;
|
||||||
|
|
||||||
|
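// Unique identifier: the name under which this module is registered with
// XlaDebugInfoManager. Initialized from the entry function name.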
|
||||||
|
string module_name_;
|
||||||
|
|
||||||
ComputeFunctionType compute_function_;
|
ComputeFunctionType compute_function_;
|
||||||
|
|
||||||
// Entry function name for the computation.
|
// Entry function name for the computation.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user