Add support for attributes in TFE_AddEagerOpToGraph API.

PiperOrigin-RevId: 231629067
This commit is contained in:
A. Unique TensorFlower 2019-01-30 11:06:18 -08:00 committed by TensorFlower Gardener
parent 34b64c3af3
commit 95f8fd4f30
3 changed files with 119 additions and 34 deletions

View File

@ -9053,9 +9053,9 @@ static TF_Output getOrCreateSymbolicTensor(TFE_TraceContext* trace_ctx,
return ret;
}
void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
TFE_TensorHandle** retvals,
int* num_retvals, TF_Status* status) {
VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": "
<< op->operation.DebugString();
@ -9066,15 +9066,19 @@ void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
for (auto* input : op->operation.Inputs()) {
auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status);
if (!status->status.ok()) return;
if (!status->status.ok()) return nullptr;
TF_AddInput(desc, symbolic_input);
}
VLOG(1) << "Adding attrs.";
// TODO(hongm): add attrs
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
for (const auto& attr : attrs) {
desc->node_builder.Attr(attr.first, attr.second);
}
auto* graph_op = TF_FinishOperation(desc, status);
if (!status->status.ok()) return;
if (!status->status.ok()) return nullptr;
VLOG(1) << "Op finalized; setting return tensors.";
*num_retvals = TF_OperationNumOutputs(graph_op);
@ -9084,6 +9088,7 @@ void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
auto dtype = TF_OperationOutputType(output);
retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype);
}
return graph_op;
}
int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) {

View File

@ -294,12 +294,11 @@ TF_CAPI_EXPORT extern void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx);
// Symbolically executes `op`, by adding a corresponding node to the graph
// associated with `trace_ctx`. This graph node outputs a set of symbolic
// tensors in `retvals` and `num_retvals`.
TF_CAPI_EXPORT extern void TFE_AddEagerOpToGraph(TFE_Op* op,
TFE_TraceContext* trace_ctx,
TFE_TensorHandle** retvals,
int* num_retvals,
TF_Status* status);
// tensors in `retvals` and `num_retvals`. Returns the corresponding graph
// operation on success, otherwise returns nullptr.
TF_CAPI_EXPORT extern TF_Operation* TFE_AddEagerOpToGraph(
TFE_Op* op, TFE_TraceContext* trace_ctx, TFE_TensorHandle** retvals,
int* num_retvals, TF_Status* status);
// Finalizes the trace graph and its inputs, and returns the number of inputs.
// After this call, the next two APIs can be called to iterate over the input

View File

@ -319,50 +319,131 @@ TEST(CAPI_EXPERIMENTAL, SymbolicTensor) {
TF_DeleteStatus(status);
}
TEST(CAPI_EXPERIMENTAL, DebugPrintAndSymbolicExecution) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
class AddEagerOpToGraphTest : public ::testing::Test {
protected:
AddEagerOpToGraphTest()
: status_(TF_NewStatus()),
eager_ctx_(nullptr),
graph_(TF_NewGraph()),
trace_ctx_(TFE_NewTraceContext(graph_)) {
TFE_ContextOptions* opts = TFE_NewContextOptions();
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
eager_ctx_ = TFE_NewContext(opts, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_DeleteContextOptions(opts);
}
~AddEagerOpToGraphTest() override {
TFE_DeleteTraceContext(trace_ctx_);
TF_DeleteGraph(graph_);
TFE_DeleteContext(eager_ctx_);
TF_DeleteStatus(status_);
}
template <typename Callable>
void AddEagerOpToGraphAndCheck(TFE_Op* op, Callable checker) {
TFE_TensorHandle* retvals[5];
int num_retvals = 5;
// Symbolically execute this op, which adds a graph node to `trace_ctx_`.
TF_Operation* graph_op =
TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CHECK_NOTNULL(graph_op);
// Check the expectations.
checker(graph_op);
for (int i = 0; i < num_retvals; ++i) {
TFE_DeleteTensorHandle(retvals[i]);
}
}
TF_Status* status_;
TFE_Context* eager_ctx_;
TF_Graph* graph_;
TFE_TraceContext* trace_ctx_;
};
TEST_F(AddEagerOpToGraphTest, DebugPrintAndSymbolicExecution) {
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* op = MatMulOp(ctx, m, m);
TFE_Op* op = MatMulOp(eager_ctx_, m, m);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpPrintDebugString(op);
auto* graph = TF_NewGraph();
auto* trace_ctx = TFE_NewTraceContext(graph);
TFE_TensorHandle* retvals[5];
int num_retvals = 5;
// Symbolically execute this op, which adds a graph node to `trace_ctx`.
TFE_AddEagerOpToGraph(op, trace_ctx, retvals, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx);
int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx_);
CHECK_EQ(num_inputs, 1);
auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx,
auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx_,
/*idx*/ 0);
LOG(INFO) << tensorflow::getTF_OutputDebugString(input_sym_tensor);
auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx,
auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx_,
/*idx*/ 0);
TFE_TensorHandlePrintDebugString(handle);
TFE_DeleteTensorHandle(handle);
CHECK_EQ(num_retvals, 1);
CHECK_EQ(TFE_TensorHandleDataType(retvals[0]), TF_FLOAT);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTraceContext(trace_ctx);
TF_DeleteGraph(graph);
TFE_DeleteTensorHandle(m);
TFE_DeleteOp(op);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST_F(AddEagerOpToGraphTest, ValueAttributesArePreserved) {
// Create MinOp
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_Op* op = MinOp(eager_ctx_, axis, axis);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
// Check the attributes set by the call to MinOp above.
AddEagerOpToGraphAndCheck(op, [this, &axis](TF_Operation* graph_op) {
unsigned char value;
TF_OperationGetAttrBool(graph_op, "keep_dims", &value, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CHECK_EQ(value, 1);
TF_DataType dtype;
TF_OperationGetAttrType(graph_op, "Tidx", &dtype, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CHECK_EQ(dtype, TF_INT32);
TF_OperationGetAttrType(graph_op, "T", &dtype, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CHECK_EQ(dtype, TFE_TensorHandleDataType(axis));
});
TFE_DeleteTensorHandle(axis);
TFE_DeleteOp(op);
}
TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) {
// Create a "Squeeze" operator with list attributes.
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_Op* squeeze = TFE_NewOp(eager_ctx_, "Squeeze", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpAddInput(squeeze, axis, status_);
TFE_OpSetAttrType(squeeze, "T", TF_INT32);
std::vector<int64_t> boundaries = {1, 2, 3, 4};
TFE_OpSetAttrIntList(squeeze, "squeeze_dims", boundaries.data(),
boundaries.size());
// Check attributes are preserved.
AddEagerOpToGraphAndCheck(
squeeze, [this, &boundaries](TF_Operation* squeeze_graph_op) {
TF_DataType dtype;
TF_OperationGetAttrType(squeeze_graph_op, "T", &dtype, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CHECK_EQ(dtype, TF_INT32);
std::unique_ptr<int64_t[]> list(new int64_t[boundaries.size()]);
TF_OperationGetAttrIntList(squeeze_graph_op, "squeeze_dims", list.get(),
boundaries.size(), status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
EXPECT_TRUE(std::equal(list.get(), list.get() + boundaries.size(),
boundaries.begin()));
});
TFE_DeleteTensorHandle(axis);
TFE_DeleteOp(squeeze);
}
} // namespace