Internal change

PiperOrigin-RevId: 351277988
Change-Id: I993d8aa6cbee200bd673bccf4039798c7e50c603
This commit is contained in:
Mingsheng Hong 2021-01-11 18:24:45 -08:00 committed by TensorFlower Gardener
parent 4d8f8d252a
commit d68ead1718
4 changed files with 65 additions and 2 deletions

View File

@ -649,3 +649,23 @@ int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
}
return tensorflow::unwrap(h)->DeviceId(&status->status);
}
void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
const std::vector<std::string>& op_names =
tensorflow::unwrap(ctx)->GetLoggedOpsTestonly();
std::ostringstream op_names_oss;
for (const auto& op : op_names) {
op_names_oss << op << ", ";
}
const std::string& op_names_str = op_names_oss.str();
void* data = tensorflow::port::Malloc(op_names_str.length());
op_names_str.copy(static_cast<char*>(data), op_names_str.length(), 0);
buf->data = data;
buf->length = op_names_str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}

View File

@ -557,6 +557,13 @@ TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
TF_Status* status);
// Get a comma-separated list of op names executed in graph functions dispatched
// to `ctx`. This feature is currently only enabled for TFRT debug builds, for
// performance and simplicity reasons.
TF_CAPI_EXPORT extern void TFE_GetExecutedOpNames(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -1041,9 +1041,10 @@ void FunctionDefAndExecute(bool async) {
TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); }
TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); }
void RunAddFunction(bool enable_grappler) {
void RunAddFunction(bool use_tfrt, bool enable_grappler) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
@ -1075,6 +1076,12 @@ void RunAddFunction(bool enable_grappler) {
serialized_config.length());
}
if (use_tfrt) {
// Set some test-only graph compiler options.
TFE_OpSetAttrBool(op, "TFRT_TEST_enable_native_ops", false);
TFE_OpSetAttrBool(op, "TFRT_TEST_enable_grappler", enable_grappler);
}
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, m, status);
@ -1096,6 +1103,23 @@ void RunAddFunction(bool enable_grappler) {
EXPECT_EQ(6, product[2]);
EXPECT_EQ(8, product[3]);
// When we turn on grappler, confirm that the tf.Add has been rewritten into a
// tf.Mul.
// This capability of checking the executed op names is currently only enabled
// for TFRT debug build, for performance and simplicity reasons.
if (use_tfrt) {
TF_Buffer* buf = TF_NewBuffer();
TFE_GetExecutedOpNames(ctx, buf, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
#ifndef NDEBUG
if (enable_grappler)
EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Mul"), nullptr);
else
EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Add"), nullptr);
#endif
TF_DeleteBuffer(buf);
}
TFE_ContextRemoveFunction(ctx, "AddFunction", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
@ -1104,9 +1128,19 @@ void RunAddFunction(bool enable_grappler) {
}
TEST(CAPI, RunAddFunctionWithGrappler) {
RunAddFunction(/*enable_grappler=*/true);
RunAddFunction(/*use_tfrt=*/false, /*enable_grappler=*/true);
}
#ifdef PLATFORM_GOOGLE
TEST(CAPI, RunAddFunction_TFRT) {
RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/false);
}
TEST(CAPI, RunAddFunctionWithGrappler_TFRT) {
RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/true);
}
#endif
void BM_ExecuteFunction(int iters, int async) {
tensorflow::testing::StopTiming();
tensorflow::testing::SetLabel(async ? "ExecuteFunctionAsync"

View File

@ -183,6 +183,8 @@ class ImmediateExecutionContext : public AbstractContext {
virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
ImmediateExecutionTensorHandle* handle) = 0;
virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; }
//===--------------------------------------------------------------------===//
// Distributed runtime related functions.
//===--------------------------------------------------------------------===//