diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 6cc74cfb324..a8325ce494c 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -9053,9 +9053,9 @@ static TF_Output getOrCreateSymbolicTensor(TFE_TraceContext* trace_ctx, return ret; } -void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, - TFE_TensorHandle** retvals, int* num_retvals, - TF_Status* status) { +TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, + TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status) { VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": " << op->operation.DebugString(); @@ -9066,15 +9066,19 @@ void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()); for (auto* input : op->operation.Inputs()) { auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status); - if (!status->status.ok()) return; + if (!status->status.ok()) return nullptr; TF_AddInput(desc, symbolic_input); } VLOG(1) << "Adding attrs."; - // TODO(hongm): add attrs + tensorflow::AttrValueMap attrs; + op->operation.Attrs().FillAttrValueMap(&attrs); + for (const auto& attr : attrs) { + desc->node_builder.Attr(attr.first, attr.second); + } auto* graph_op = TF_FinishOperation(desc, status); - if (!status->status.ok()) return; + if (!status->status.ok()) return nullptr; VLOG(1) << "Op finalized; setting return tensors."; *num_retvals = TF_OperationNumOutputs(graph_op); @@ -9084,6 +9088,7 @@ void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, auto dtype = TF_OperationOutputType(output); retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype); } + return graph_op; } int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) { diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 48ea0ec1ed7..8d1a8b82fba 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -294,12 +294,11 @@ TF_CAPI_EXPORT extern void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx); // Symbolically executes `op`, by adding a corresponding node to the graph // associated with `trace_ctx`. This graph node outputs a set of symbolic -// tensors in `retvals` and `num_retvals`. -TF_CAPI_EXPORT extern void TFE_AddEagerOpToGraph(TFE_Op* op, - TFE_TraceContext* trace_ctx, - TFE_TensorHandle** retvals, - int* num_retvals, - TF_Status* status); +// tensors in `retvals` and `num_retvals`. Returns the corresponding graph +// operation on success, otherwise returns nullptr. +TF_CAPI_EXPORT extern TF_Operation* TFE_AddEagerOpToGraph( + TFE_Op* op, TFE_TraceContext* trace_ctx, TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status); // Finalizes the trace graph and its inputs, and returns the number of inputs. // After this call, the next two APIs can be called to iterate over the input diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 4cfcf2ef3b2..354ee5f49f3 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -319,50 +319,131 @@ TEST(CAPI_EXPERIMENTAL, SymbolicTensor) { TF_DeleteStatus(status); } -TEST(CAPI_EXPERIMENTAL, DebugPrintAndSymbolicExecution) { - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_Context* ctx = TFE_NewContext(opts, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); +class AddEagerOpToGraphTest : public ::testing::Test { + protected: + AddEagerOpToGraphTest() + : status_(TF_NewStatus()), + eager_ctx_(nullptr), + graph_(TF_NewGraph()), + trace_ctx_(TFE_NewTraceContext(graph_)) { + TFE_ContextOptions* opts = TFE_NewContextOptions(); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + eager_ctx_ = TFE_NewContext(opts, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_DeleteContextOptions(opts); + } + ~AddEagerOpToGraphTest() override { + TFE_DeleteTraceContext(trace_ctx_); + TF_DeleteGraph(graph_); + TFE_DeleteContext(eager_ctx_); + TF_DeleteStatus(status_); + } + + template + void AddEagerOpToGraphAndCheck(TFE_Op* op, Callable checker) { + TFE_TensorHandle* retvals[5]; + int num_retvals = 5; + // Symbolically execute this op, which adds a graph node to `trace_ctx_`. + TF_Operation* graph_op = + TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_NOTNULL(graph_op); + // Check the expectations. + checker(graph_op); + for (int i = 0; i < num_retvals; ++i) { + TFE_DeleteTensorHandle(retvals[i]); + } + } + + TF_Status* status_; + TFE_Context* eager_ctx_; + TF_Graph* graph_; + TFE_TraceContext* trace_ctx_; +}; + +TEST_F(AddEagerOpToGraphTest, DebugPrintAndSymbolicExecution) { TFE_TensorHandle* m = TestMatrixTensorHandle(); - TFE_Op* op = MatMulOp(ctx, m, m); + TFE_Op* op = MatMulOp(eager_ctx_, m, m); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); TFE_OpPrintDebugString(op); - auto* graph = TF_NewGraph(); - auto* trace_ctx = TFE_NewTraceContext(graph); TFE_TensorHandle* retvals[5]; int num_retvals = 5; // Symbolically execute this op, which adds a graph node to `trace_ctx`. - TFE_AddEagerOpToGraph(op, trace_ctx, retvals, &num_retvals, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); - int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx); + int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx_); CHECK_EQ(num_inputs, 1); - auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx, + auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx_, /*idx*/ 0); LOG(INFO) << tensorflow::getTF_OutputDebugString(input_sym_tensor); - auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx, + auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx_, /*idx*/ 0); TFE_TensorHandlePrintDebugString(handle); TFE_DeleteTensorHandle(handle); CHECK_EQ(num_retvals, 1); CHECK_EQ(TFE_TensorHandleDataType(retvals[0]), TF_FLOAT); + TFE_DeleteTensorHandle(retvals[0]); - - TFE_DeleteTraceContext(trace_ctx); - TF_DeleteGraph(graph); - TFE_DeleteTensorHandle(m); TFE_DeleteOp(op); - TFE_DeleteContext(ctx); - TF_DeleteStatus(status); +} + +TEST_F(AddEagerOpToGraphTest, ValueAttributesArePreserved) { + // Create MinOp + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* op = MinOp(eager_ctx_, axis, axis); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + // Check the attributes set by the call to MinOp above. + AddEagerOpToGraphAndCheck(op, [this, &axis](TF_Operation* graph_op) { + unsigned char value; + TF_OperationGetAttrBool(graph_op, "keep_dims", &value, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(value, 1); + TF_DataType dtype; + TF_OperationGetAttrType(graph_op, "Tidx", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TF_INT32); + TF_OperationGetAttrType(graph_op, "T", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TFE_TensorHandleDataType(axis)); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteOp(op); +} + +TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) { + // Create a "Squeeze" operator with list attributes. + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* squeeze = TFE_NewOp(eager_ctx_, "Squeeze", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpAddInput(squeeze, axis, status_); + TFE_OpSetAttrType(squeeze, "T", TF_INT32); + std::vector boundaries = {1, 2, 3, 4}; + TFE_OpSetAttrIntList(squeeze, "squeeze_dims", boundaries.data(), + boundaries.size()); + // Check attributes are preserved. + AddEagerOpToGraphAndCheck( + squeeze, [this, &boundaries](TF_Operation* squeeze_graph_op) { + TF_DataType dtype; + TF_OperationGetAttrType(squeeze_graph_op, "T", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TF_INT32); + std::unique_ptr list(new int64_t[boundaries.size()]); + TF_OperationGetAttrIntList(squeeze_graph_op, "squeeze_dims", list.get(), + boundaries.size(), status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_TRUE(std::equal(list.get(), list.get() + boundaries.size(), + boundaries.begin())); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteOp(squeeze); } } // namespace