Make profiler C APIs portable by serializing opaque protobufs into a user provided buffer.
WHAT * Make profiler APIs in libtpu portable. * CollectData() takes in a user sized buffer that is big enough to contain a serialized XSpace from the TPU driver. WHY * Minimize number of APIs to maintain. * Minimize number of serialize and deserializations of XSpace. * Eliminate incompatibilities and crashes as a result of passing protobufs at shared library boundaries. * Untangle ownership and lifetime of resources; buffer for serializing is used and owned by the client, collected XSpace is owned by driver until after serialization. Misc clean ups: * Fix mismatch of TpuProfiler_Free definition vs declaration. * Fixed a leak in TpuProfiler_Create() on error conditions. Note: TPU driver refers to libtpu. PiperOrigin-RevId: 355299692 Change-Id: Ie37c295c20c29bb511e9e969d785de57fe1446fd
This commit is contained in:
parent
fe5dad047b
commit
b70b22739c
@ -58,16 +58,22 @@ class TpuTracer : public ProfilerInterface {
|
||||
};
|
||||
|
||||
TpuTracer::TpuTracer() {
|
||||
tpu_profiler_ = tpu::OpsApiFn()->TpuProfiler_CreateFn();
|
||||
StatusHelper status;
|
||||
tpu::OpsApiFn()->TpuProfiler_CreateFn(&tpu_profiler_, status.c_status);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status.status().error_message();
|
||||
}
|
||||
}
|
||||
|
||||
TpuTracer::~TpuTracer() { tpu::OpsApiFn()->TpuProfiler_FreeFn(tpu_profiler_); }
|
||||
TpuTracer::~TpuTracer() {
|
||||
tpu::OpsApiFn()->TpuProfiler_DestroyFn(tpu_profiler_);
|
||||
}
|
||||
|
||||
Status TpuTracer::Start() {
|
||||
StatusHelper status;
|
||||
tpu::OpsApiFn()->TpuProfiler_StartFn(tpu_profiler_, status.c_status);
|
||||
if (!status.ok()) {
|
||||
VLOG(1) << "Run Start failed.";
|
||||
LOG(ERROR) << "TPU tracer failed to start.";
|
||||
return status.status();
|
||||
}
|
||||
return Status::OK();
|
||||
@ -77,7 +83,7 @@ Status TpuTracer::Stop() {
|
||||
StatusHelper status;
|
||||
tpu::OpsApiFn()->TpuProfiler_StopFn(tpu_profiler_, status.c_status);
|
||||
if (!status.ok()) {
|
||||
VLOG(1) << "Run Stop failed.";
|
||||
LOG(ERROR) << "TPU tracer failed to stop.";
|
||||
return status.status();
|
||||
}
|
||||
return Status::OK();
|
||||
@ -90,10 +96,21 @@ Status TpuTracer::CollectData(RunMetadata* run_metadata) {
|
||||
|
||||
Status TpuTracer::CollectData(XSpace* space) {
|
||||
StatusHelper status;
|
||||
// Get size of buffer required for TPU driver to serialize XSpace into.
|
||||
size_t size_in_bytes;
|
||||
tpu::OpsApiFn()->TpuProfiler_CollectDataFn(tpu_profiler_, status.c_status,
|
||||
space);
|
||||
/*buffer=*/nullptr,
|
||||
&size_in_bytes);
|
||||
// Prepare an appropriately sized buffer.
|
||||
if (size_in_bytes > 0) {
|
||||
std::vector<uint8_t> buffer(size_in_bytes);
|
||||
tpu::OpsApiFn()->TpuProfiler_CollectDataFn(tpu_profiler_, status.c_status,
|
||||
buffer.data(), &size_in_bytes);
|
||||
// Deserialize XSpace from the buffer and return it.
|
||||
space->ParseFromArray(buffer.data(), buffer.size());
|
||||
}
|
||||
if (!status.ok()) {
|
||||
VLOG(1) << "Run CollectData failed.";
|
||||
LOG(ERROR) << "TPU tracer failed to collect data.";
|
||||
return status.status();
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -23,7 +23,7 @@ TfTpu_BaseFn* InitializeApiFn() {
|
||||
return &base_fn;
|
||||
}
|
||||
|
||||
TfTpu_OpsApiFn* OpsApiFn() {
|
||||
const TfTpu_OpsApiFn* OpsApiFn() {
|
||||
static TfTpu_OpsApiFn ops_api_fn;
|
||||
return &ops_api_fn;
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ namespace tpu {
|
||||
|
||||
TfTpu_BaseFn* InitializeApiFn();
|
||||
|
||||
TfTpu_OpsApiFn* OpsApiFn();
|
||||
const TfTpu_OpsApiFn* OpsApiFn();
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -7,7 +7,9 @@
|
||||
namespace {
|
||||
|
||||
tensorflow::Status SetTpuOpsStructFns(void* library_handle) {
|
||||
auto* ops_api_fn = tensorflow::tpu::OpsApiFn();
|
||||
// Constant cast so that we can initialize the functions. The functions are
|
||||
// mutable here because this is the only place where they are initialized.
|
||||
auto* ops_api_fn = const_cast<TfTpu_OpsApiFn*>(tensorflow::tpu::OpsApiFn());
|
||||
|
||||
TFTPU_SET_FN(ops_api_fn, ConfigureDistributedTpuOp_DoWork);
|
||||
TFTPU_SET_FN(ops_api_fn, WaitForDistributedTpuOp_DoWork);
|
||||
@ -70,9 +72,9 @@ tensorflow::Status SetTpuOpsStructFns(void* library_handle) {
|
||||
TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateCompilationCacheKey);
|
||||
TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey);
|
||||
TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint);
|
||||
|
||||
|
||||
TFTPU_SET_FN(ops_api_fn, TpuProfiler_Create);
|
||||
TFTPU_SET_FN(ops_api_fn, TpuProfiler_Free);
|
||||
TFTPU_SET_FN(ops_api_fn, TpuProfiler_Destroy);
|
||||
TFTPU_SET_FN(ops_api_fn, TpuProfiler_Start);
|
||||
TFTPU_SET_FN(ops_api_fn, TpuProfiler_Stop);
|
||||
TFTPU_SET_FN(ops_api_fn, TpuProfiler_CollectData);
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
||||
#include "tensorflow/core/tpu/libtftpu.h"
|
||||
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
|
||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||
@ -106,20 +105,40 @@ TFTPU_CAPI_EXPORT void TpuCompile_XrtCompileAndBuild(
|
||||
TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state,
|
||||
XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);
|
||||
|
||||
// Creates a new TPU profiler object.
|
||||
TFTPU_CAPI_EXPORT TpuProfiler* TpuProfiler_Create();
|
||||
|
||||
TFTPU_CAPI_EXPORT TpuProfiler* TpuProfiler_Free(TpuProfiler* tpu_profiler);
|
||||
|
||||
// Creates a TPU profiler that is ready to start profiling.
|
||||
TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler,
|
||||
TF_Status* status);
|
||||
// Destroys the given TPU profiler.
|
||||
TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler);
|
||||
// Starts profiling if not already started, returns an error otherwise.
|
||||
TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler,
|
||||
TF_Status* status);
|
||||
|
||||
// Stops profiling if not already stopped, returns an error otherwise.
|
||||
TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler,
|
||||
TF_Status* status);
|
||||
|
||||
TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(
|
||||
TpuProfiler* tpu_profiler, TF_Status* status,
|
||||
tensorflow::profiler::XSpace* space);
|
||||
// Serializes profiled data into `buffer` and returns the size of `buffer`. The
|
||||
// profile data held by the TPU driver will be cleared after retrieval.
|
||||
//
|
||||
// Step 1. Query the size of buffer required into `size_in_bytes`.
|
||||
//
|
||||
// size_t size_in_bytes;
|
||||
// TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes);
|
||||
//
|
||||
// Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`.
|
||||
// Subsequently,The TPU driver clears its copy of the profile data.
|
||||
//
|
||||
// uint8_t buffer = new uint8_t[size_in_bytes];
|
||||
// TpuProfiler_CollectData(profiler, status, buffer, size_in_bytes);
|
||||
//
|
||||
// Step 3. Unpack the data into an XSpace.
|
||||
//
|
||||
// tensorflow::profiler::XSpace space;
|
||||
// space.ParseFromArray(buffer, size_in_bytes);
|
||||
//
|
||||
TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler,
|
||||
TF_Status* status,
|
||||
uint8_t* buffer,
|
||||
size_t* size_in_bytes);
|
||||
|
||||
// Creates a new TPU mesh state object.
|
||||
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
|
||||
@ -416,7 +435,7 @@ struct TfTpu_OpsApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Free);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData);
|
||||
|
Loading…
Reference in New Issue
Block a user