Internal change
PiperOrigin-RevId: 351277988 Change-Id: I993d8aa6cbee200bd673bccf4039798c7e50c603
This commit is contained in:
parent
4d8f8d252a
commit
d68ead1718
@ -649,3 +649,23 @@ int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
|
||||
}
|
||||
return tensorflow::unwrap(h)->DeviceId(&status->status);
|
||||
}
|
||||
|
||||
void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
const std::vector<std::string>& op_names =
|
||||
tensorflow::unwrap(ctx)->GetLoggedOpsTestonly();
|
||||
|
||||
std::ostringstream op_names_oss;
|
||||
for (const auto& op : op_names) {
|
||||
op_names_oss << op << ", ";
|
||||
}
|
||||
const std::string& op_names_str = op_names_oss.str();
|
||||
void* data = tensorflow::port::Malloc(op_names_str.length());
|
||||
op_names_str.copy(static_cast<char*>(data), op_names_str.length(), 0);
|
||||
buf->data = data;
|
||||
buf->length = op_names_str.length();
|
||||
buf->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
@ -557,6 +557,13 @@ TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
|
||||
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
// Get a comma-separated list of op names executed in graph functions dispatched
|
||||
// to `ctx`. This feature is currently only enabled for TFRT debug builds, for
|
||||
// performance and simplicity reasons.
|
||||
TF_CAPI_EXPORT extern void TFE_GetExecutedOpNames(TFE_Context* ctx,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -1041,9 +1041,10 @@ void FunctionDefAndExecute(bool async) {
|
||||
TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); }
|
||||
TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); }
|
||||
|
||||
void RunAddFunction(bool enable_grappler) {
|
||||
void RunAddFunction(bool use_tfrt, bool enable_grappler) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
@ -1075,6 +1076,12 @@ void RunAddFunction(bool enable_grappler) {
|
||||
serialized_config.length());
|
||||
}
|
||||
|
||||
if (use_tfrt) {
|
||||
// Set some test-only graph compiler options.
|
||||
TFE_OpSetAttrBool(op, "TFRT_TEST_enable_native_ops", false);
|
||||
TFE_OpSetAttrBool(op, "TFRT_TEST_enable_grappler", enable_grappler);
|
||||
}
|
||||
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_OpAddInput(op, m, status);
|
||||
@ -1096,6 +1103,23 @@ void RunAddFunction(bool enable_grappler) {
|
||||
EXPECT_EQ(6, product[2]);
|
||||
EXPECT_EQ(8, product[3]);
|
||||
|
||||
// When we turn on grappler, confirm that the tf.Add has been rewritten into a
|
||||
// tf.Mul.
|
||||
// This capability of checking the executed op names is currently only enabled
|
||||
// for TFRT debug build, for performance and simplicity reasons.
|
||||
if (use_tfrt) {
|
||||
TF_Buffer* buf = TF_NewBuffer();
|
||||
TFE_GetExecutedOpNames(ctx, buf, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
#ifndef NDEBUG
|
||||
if (enable_grappler)
|
||||
EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Mul"), nullptr);
|
||||
else
|
||||
EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Add"), nullptr);
|
||||
#endif
|
||||
TF_DeleteBuffer(buf);
|
||||
}
|
||||
|
||||
TFE_ContextRemoveFunction(ctx, "AddFunction", status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
@ -1104,9 +1128,19 @@ void RunAddFunction(bool enable_grappler) {
|
||||
}
|
||||
|
||||
TEST(CAPI, RunAddFunctionWithGrappler) {
|
||||
RunAddFunction(/*enable_grappler=*/true);
|
||||
RunAddFunction(/*use_tfrt=*/false, /*enable_grappler=*/true);
|
||||
}
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
TEST(CAPI, RunAddFunction_TFRT) {
|
||||
RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RunAddFunctionWithGrappler_TFRT) {
|
||||
RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/true);
|
||||
}
|
||||
#endif
|
||||
|
||||
void BM_ExecuteFunction(int iters, int async) {
|
||||
tensorflow::testing::StopTiming();
|
||||
tensorflow::testing::SetLabel(async ? "ExecuteFunctionAsync"
|
||||
|
@ -183,6 +183,8 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
|
||||
ImmediateExecutionTensorHandle* handle) = 0;
|
||||
|
||||
virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Distributed runtime related functions.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
Loading…
Reference in New Issue
Block a user