From d68ead1718f34ba205edcb931a7a1865c9b81201 Mon Sep 17 00:00:00 2001 From: Mingsheng Hong Date: Mon, 11 Jan 2021 18:24:45 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 351277988 Change-Id: I993d8aa6cbee200bd673bccf4039798c7e50c603 --- tensorflow/c/eager/c_api_experimental.cc | 20 ++++++++++ tensorflow/c/eager/c_api_experimental.h | 7 ++++ tensorflow/c/eager/c_api_test.cc | 38 ++++++++++++++++++- .../c/eager/immediate_execution_context.h | 2 + 4 files changed, 65 insertions(+), 2 deletions(-) diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index b8a3d3ee09e..90e9cdc162d 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -649,3 +649,23 @@ int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) { } return tensorflow::unwrap(h)->DeviceId(&status->status); } + +void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf, + TF_Status* status) { + const std::vector& op_names = + tensorflow::unwrap(ctx)->GetLoggedOpsTestonly(); + + std::ostringstream op_names_oss; + for (const auto& op : op_names) { + op_names_oss << op << ", "; + } + const std::string& op_names_str = op_names_oss.str(); + void* data = tensorflow::port::Malloc(op_names_str.length()); + op_names_str.copy(static_cast(data), op_names_str.length(), 0); + buf->data = data; + buf->length = op_names_str.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; + status->status = tensorflow::Status::OK(); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 54327199135..30044244acf 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -557,6 +557,13 @@ TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType( TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status); +// Get a comma-separated list of op names executed in graph functions dispatched +// to `ctx`. This feature is currently only enabled for TFRT debug builds, for +// performance and simplicity reasons. +TF_CAPI_EXPORT extern void TFE_GetExecutedOpNames(TFE_Context* ctx, + TF_Buffer* buf, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 9a0cb212083..3037669ac9d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1041,9 +1041,10 @@ void FunctionDefAndExecute(bool async) { TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); } TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); } -void RunAddFunction(bool enable_grappler) { +void RunAddFunction(bool use_tfrt, bool enable_grappler) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, use_tfrt); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -1075,6 +1076,12 @@ void RunAddFunction(bool enable_grappler) { serialized_config.length()); } + if (use_tfrt) { + // Set some test-only graph compiler options. + TFE_OpSetAttrBool(op, "TFRT_TEST_enable_native_ops", false); + TFE_OpSetAttrBool(op, "TFRT_TEST_enable_grappler", enable_grappler); + } + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_OpAddInput(op, m, status); @@ -1096,6 +1103,23 @@ void RunAddFunction(bool enable_grappler) { EXPECT_EQ(6, product[2]); EXPECT_EQ(8, product[3]); + // When we turn on grappler, confirm that the tf.Add has been rewritten into a + // tf.Mul. + // This capability of checking the executed op names is currently only enabled + // for TFRT debug build, for performance and simplicity reasons. + if (use_tfrt) { + TF_Buffer* buf = TF_NewBuffer(); + TFE_GetExecutedOpNames(ctx, buf, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); +#ifndef NDEBUG + if (enable_grappler) + EXPECT_NE(strstr(static_cast(buf->data), "tf.Mul"), nullptr); + else + EXPECT_NE(strstr(static_cast(buf->data), "tf.Add"), nullptr); +#endif + TF_DeleteBuffer(buf); + } + TFE_ContextRemoveFunction(ctx, "AddFunction", status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TFE_DeleteContext(ctx); @@ -1104,9 +1128,19 @@ void RunAddFunction(bool enable_grappler) { } TEST(CAPI, RunAddFunctionWithGrappler) { - RunAddFunction(/*enable_grappler=*/true); + RunAddFunction(/*use_tfrt=*/false, /*enable_grappler=*/true); } +#ifdef PLATFORM_GOOGLE +TEST(CAPI, RunAddFunction_TFRT) { + RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/false); +} + +TEST(CAPI, RunAddFunctionWithGrappler_TFRT) { + RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/true); +} +#endif + void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); tensorflow::testing::SetLabel(async ? "ExecuteFunctionAsync" diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index e557753c49d..065534421f5 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -183,6 +183,8 @@ class ImmediateExecutionContext : public AbstractContext { virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface( ImmediateExecutionTensorHandle* handle) = 0; + virtual std::vector GetLoggedOpsTestonly() { return {}; } + //===--------------------------------------------------------------------===// // Distributed runtime related functions. //===--------------------------------------------------------------------===//