From b70b22739cc41c004e9ff6a0ad6b599b308d47dd Mon Sep 17 00:00:00 2001
From: Yi Situ <yisitu@google.com>
Date: Tue, 2 Feb 2021 18:48:04 -0800
Subject: [PATCH] Make profiler C APIs portable by serializing opaque protobufs
 into a user provided buffer.

WHAT
* Make profiler APIs in libtpu portable.
* CollectData() takes in a user sized buffer that is big enough to contain a serialized XSpace from the TPU driver.

WHY
* Minimize number of APIs to maintain.
* Minimize number of serialize and deserializations of XSpace.
* Eliminate incompatibilities and crashes as a result of passing protobufs at shared library boundaries.
* Untangle ownership and lifetime of resources; buffer for serializing is used and owned by the client, collected XSpace is owned by driver until after serialization.

Misc clean ups:
* Fix mismatch of TpuProfiler_Free definition vs declaration.
* Fixed a leak in TpuProfiler_Create() on error conditions.

Note: TPU driver refers to libtpu.
PiperOrigin-RevId: 355299692
Change-Id: Ie37c295c20c29bb511e9e969d785de57fe1446fd
---
 .../core/profiler/internal/tpu/tpu_tracer.cc  | 29 ++++++++++---
 tensorflow/core/tpu/tpu_api.cc                |  2 +-
 tensorflow/core/tpu/tpu_api.h                 |  2 +-
 tensorflow/core/tpu/tpu_library_init_fns.inc  |  8 ++--
 tensorflow/core/tpu/tpu_ops_c_api.h           | 43 +++++++++++++------
 5 files changed, 61 insertions(+), 23 deletions(-)

diff --git a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
index a156c1713f3..dbe32ee043c 100644
--- a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
+++ b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
@@ -58,16 +58,22 @@ class TpuTracer : public ProfilerInterface {
 };
 
 TpuTracer::TpuTracer() {
-  tpu_profiler_ = tpu::OpsApiFn()->TpuProfiler_CreateFn();
+  StatusHelper status;
+  tpu::OpsApiFn()->TpuProfiler_CreateFn(&tpu_profiler_, status.c_status);
+  if (!status.ok()) {
+    LOG(ERROR) << status.status().error_message();
+  }
 }
 
-TpuTracer::~TpuTracer() { tpu::OpsApiFn()->TpuProfiler_FreeFn(tpu_profiler_); }
+TpuTracer::~TpuTracer() {
+  tpu::OpsApiFn()->TpuProfiler_DestroyFn(tpu_profiler_);
+}
 
 Status TpuTracer::Start() {
   StatusHelper status;
   tpu::OpsApiFn()->TpuProfiler_StartFn(tpu_profiler_, status.c_status);
   if (!status.ok()) {
-    VLOG(1) << "Run Start failed.";
+    LOG(ERROR) << "TPU tracer failed to start.";
     return status.status();
   }
   return Status::OK();
@@ -77,7 +83,7 @@ Status TpuTracer::Stop() {
   StatusHelper status;
   tpu::OpsApiFn()->TpuProfiler_StopFn(tpu_profiler_, status.c_status);
   if (!status.ok()) {
-    VLOG(1) << "Run Stop failed.";
+    LOG(ERROR) << "TPU tracer failed to stop.";
     return status.status();
   }
   return Status::OK();
@@ -90,10 +96,21 @@ Status TpuTracer::CollectData(RunMetadata* run_metadata) {
 
 Status TpuTracer::CollectData(XSpace* space) {
   StatusHelper status;
+  // Get size of buffer required for TPU driver to serialize XSpace into.
+  size_t size_in_bytes;
   tpu::OpsApiFn()->TpuProfiler_CollectDataFn(tpu_profiler_, status.c_status,
-                                             space);
+                                             /*buffer=*/nullptr,
+                                             &size_in_bytes);
+  // Prepare an appropriately sized buffer.
+  if (size_in_bytes > 0) {
+    std::vector<uint8_t> buffer(size_in_bytes);
+    tpu::OpsApiFn()->TpuProfiler_CollectDataFn(tpu_profiler_, status.c_status,
+                                               buffer.data(), &size_in_bytes);
+    // Deserialize XSpace from the buffer and return it.
+    space->ParseFromArray(buffer.data(), buffer.size());
+  }
   if (!status.ok()) {
-    VLOG(1) << "Run CollectData failed.";
+    LOG(ERROR) << "TPU tracer failed to collect data.";
     return status.status();
   }
   return Status::OK();
diff --git a/tensorflow/core/tpu/tpu_api.cc b/tensorflow/core/tpu/tpu_api.cc
index 690e2049652..339e8ef4d83 100644
--- a/tensorflow/core/tpu/tpu_api.cc
+++ b/tensorflow/core/tpu/tpu_api.cc
@@ -23,7 +23,7 @@ TfTpu_BaseFn* InitializeApiFn() {
   return &base_fn;
 }
 
-TfTpu_OpsApiFn* OpsApiFn() {
+const TfTpu_OpsApiFn* OpsApiFn() {
   static TfTpu_OpsApiFn ops_api_fn;
   return &ops_api_fn;
 }
diff --git a/tensorflow/core/tpu/tpu_api.h b/tensorflow/core/tpu/tpu_api.h
index b880f4ed9cf..45ada404275 100644
--- a/tensorflow/core/tpu/tpu_api.h
+++ b/tensorflow/core/tpu/tpu_api.h
@@ -25,7 +25,7 @@ namespace tpu {
 
 TfTpu_BaseFn* InitializeApiFn();
 
-TfTpu_OpsApiFn* OpsApiFn();
+const TfTpu_OpsApiFn* OpsApiFn();
 
 }  // namespace tpu
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index 0b984fa2a75..340077f054c 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -7,7 +7,9 @@
 namespace {
 
 tensorflow::Status SetTpuOpsStructFns(void* library_handle) {
-  auto* ops_api_fn = tensorflow::tpu::OpsApiFn();
+  // Constant cast so that we can initialize the functions. The functions are
+  // mutable here because this is the only place where they are initialized.
+  auto* ops_api_fn = const_cast<TfTpu_OpsApiFn*>(tensorflow::tpu::OpsApiFn());
 
   TFTPU_SET_FN(ops_api_fn, ConfigureDistributedTpuOp_DoWork);
   TFTPU_SET_FN(ops_api_fn, WaitForDistributedTpuOp_DoWork);
@@ -70,9 +72,9 @@ tensorflow::Status SetTpuOpsStructFns(void* library_handle) {
   TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateCompilationCacheKey);
   TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey);
   TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint);
-  
+
   TFTPU_SET_FN(ops_api_fn, TpuProfiler_Create);
-  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Free);
+  TFTPU_SET_FN(ops_api_fn, TpuProfiler_Destroy);
   TFTPU_SET_FN(ops_api_fn, TpuProfiler_Start);
   TFTPU_SET_FN(ops_api_fn, TpuProfiler_Stop);
   TFTPU_SET_FN(ops_api_fn, TpuProfiler_CollectData);
diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h
index f361110f975..a84579c93f6 100644
--- a/tensorflow/core/tpu/tpu_ops_c_api.h
+++ b/tensorflow/core/tpu/tpu_ops_c_api.h
@@ -19,7 +19,6 @@ limitations under the License.
 
 #include <cstdint>
 
-#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/tpu/libtftpu.h"
 #include "tensorflow/stream_executor/tpu/c_api_decl.h"
 #include "tensorflow/stream_executor/tpu/proto_helper.h"
@@ -106,20 +105,40 @@ TFTPU_CAPI_EXPORT void TpuCompile_XrtCompileAndBuild(
     TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state,
     XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);
 
-// Creates a new TPU profiler object.
-TFTPU_CAPI_EXPORT TpuProfiler* TpuProfiler_Create();
-
-TFTPU_CAPI_EXPORT TpuProfiler* TpuProfiler_Free(TpuProfiler* tpu_profiler);
-
+// Creates a TPU profiler that is ready to start profiling.
+TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler,
+                                          TF_Status* status);
+// Destroys the given TPU profiler.
+TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler);
+// Starts profiling if not already started, returns an error otherwise.
 TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler,
                                          TF_Status* status);
-
+// Stops profiling if not already stopped, returns an error otherwise.
 TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler,
                                         TF_Status* status);
-
-TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(
-    TpuProfiler* tpu_profiler, TF_Status* status,
-    tensorflow::profiler::XSpace* space);
+// Serializes profiled data into `buffer` and returns the size of `buffer`. The
+// profile data held by the TPU driver will be cleared after retrieval.
+//
+// Step 1. Query the size of buffer required into `size_in_bytes`.
+//
+//   size_t size_in_bytes;
+//   TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes);
+//
+// Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`.
+//         Subsequently,The TPU driver clears its copy of the profile data.
+//
+//   uint8_t buffer = new uint8_t[size_in_bytes];
+//   TpuProfiler_CollectData(profiler, status, buffer, size_in_bytes);
+//
+// Step 3. Unpack the data into an XSpace.
+//
+//   tensorflow::profiler::XSpace space;
+//   space.ParseFromArray(buffer, size_in_bytes);
+//
+TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler,
+                                               TF_Status* status,
+                                               uint8_t* buffer,
+                                               size_t* size_in_bytes);
 
 // Creates a new TPU mesh state object.
 TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
@@ -416,7 +435,7 @@ struct TfTpu_OpsApiFn {
   TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
 
   TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create);
-  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Free);
+  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy);
   TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start);
   TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop);
   TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData);