Add support for attributes in TFE_AddEagerOpToGraph API.
PiperOrigin-RevId: 231629067
This commit is contained in:
parent
34b64c3af3
commit
95f8fd4f30
@ -9053,9 +9053,9 @@ static TF_Output getOrCreateSymbolicTensor(TFE_TraceContext* trace_ctx,
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||||
TFE_TensorHandle** retvals, int* num_retvals,
|
TFE_TensorHandle** retvals,
|
||||||
TF_Status* status) {
|
int* num_retvals, TF_Status* status) {
|
||||||
VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": "
|
VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": "
|
||||||
<< op->operation.DebugString();
|
<< op->operation.DebugString();
|
||||||
|
|
||||||
@ -9066,15 +9066,19 @@ void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
|||||||
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
|
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
|
||||||
for (auto* input : op->operation.Inputs()) {
|
for (auto* input : op->operation.Inputs()) {
|
||||||
auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status);
|
auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status);
|
||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return nullptr;
|
||||||
TF_AddInput(desc, symbolic_input);
|
TF_AddInput(desc, symbolic_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
VLOG(1) << "Adding attrs.";
|
VLOG(1) << "Adding attrs.";
|
||||||
// TODO(hongm): add attrs
|
tensorflow::AttrValueMap attrs;
|
||||||
|
op->operation.Attrs().FillAttrValueMap(&attrs);
|
||||||
|
for (const auto& attr : attrs) {
|
||||||
|
desc->node_builder.Attr(attr.first, attr.second);
|
||||||
|
}
|
||||||
|
|
||||||
auto* graph_op = TF_FinishOperation(desc, status);
|
auto* graph_op = TF_FinishOperation(desc, status);
|
||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return nullptr;
|
||||||
|
|
||||||
VLOG(1) << "Op finalized; setting return tensors.";
|
VLOG(1) << "Op finalized; setting return tensors.";
|
||||||
*num_retvals = TF_OperationNumOutputs(graph_op);
|
*num_retvals = TF_OperationNumOutputs(graph_op);
|
||||||
@ -9084,6 +9088,7 @@ void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
|||||||
auto dtype = TF_OperationOutputType(output);
|
auto dtype = TF_OperationOutputType(output);
|
||||||
retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype);
|
retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype);
|
||||||
}
|
}
|
||||||
|
return graph_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) {
|
int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) {
|
||||||
|
@ -294,12 +294,11 @@ TF_CAPI_EXPORT extern void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx);
|
|||||||
|
|
||||||
// Symbolically executes `op`, by adding a corresponding node to the graph
|
// Symbolically executes `op`, by adding a corresponding node to the graph
|
||||||
// associated with `trace_ctx`. This graph node outputs a set of symbolic
|
// associated with `trace_ctx`. This graph node outputs a set of symbolic
|
||||||
// tensors in `retvals` and `num_retvals`.
|
// tensors in `retvals` and `num_retvals`. Returns the corresponding graph
|
||||||
TF_CAPI_EXPORT extern void TFE_AddEagerOpToGraph(TFE_Op* op,
|
// operation on success, otherwise returns nullptr.
|
||||||
TFE_TraceContext* trace_ctx,
|
TF_CAPI_EXPORT extern TF_Operation* TFE_AddEagerOpToGraph(
|
||||||
TFE_TensorHandle** retvals,
|
TFE_Op* op, TFE_TraceContext* trace_ctx, TFE_TensorHandle** retvals,
|
||||||
int* num_retvals,
|
int* num_retvals, TF_Status* status);
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
// Finalizes the trace graph and its inputs, and returns the number of inputs.
|
// Finalizes the trace graph and its inputs, and returns the number of inputs.
|
||||||
// After this call, the next two APIs can be called to iterate over the input
|
// After this call, the next two APIs can be called to iterate over the input
|
||||||
|
@ -319,50 +319,131 @@ TEST(CAPI_EXPERIMENTAL, SymbolicTensor) {
|
|||||||
TF_DeleteStatus(status);
|
TF_DeleteStatus(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI_EXPERIMENTAL, DebugPrintAndSymbolicExecution) {
|
class AddEagerOpToGraphTest : public ::testing::Test {
|
||||||
TF_Status* status = TF_NewStatus();
|
protected:
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
AddEagerOpToGraphTest()
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
: status_(TF_NewStatus()),
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
eager_ctx_(nullptr),
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
graph_(TF_NewGraph()),
|
||||||
TFE_DeleteContextOptions(opts);
|
trace_ctx_(TFE_NewTraceContext(graph_)) {
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
eager_ctx_ = TFE_NewContext(opts, status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
}
|
||||||
|
|
||||||
|
~AddEagerOpToGraphTest() override {
|
||||||
|
TFE_DeleteTraceContext(trace_ctx_);
|
||||||
|
TF_DeleteGraph(graph_);
|
||||||
|
TFE_DeleteContext(eager_ctx_);
|
||||||
|
TF_DeleteStatus(status_);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Callable>
|
||||||
|
void AddEagerOpToGraphAndCheck(TFE_Op* op, Callable checker) {
|
||||||
|
TFE_TensorHandle* retvals[5];
|
||||||
|
int num_retvals = 5;
|
||||||
|
// Symbolically execute this op, which adds a graph node to `trace_ctx_`.
|
||||||
|
TF_Operation* graph_op =
|
||||||
|
TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
CHECK_NOTNULL(graph_op);
|
||||||
|
// Check the expectations.
|
||||||
|
checker(graph_op);
|
||||||
|
for (int i = 0; i < num_retvals; ++i) {
|
||||||
|
TFE_DeleteTensorHandle(retvals[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_Status* status_;
|
||||||
|
TFE_Context* eager_ctx_;
|
||||||
|
TF_Graph* graph_;
|
||||||
|
TFE_TraceContext* trace_ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(AddEagerOpToGraphTest, DebugPrintAndSymbolicExecution) {
|
||||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||||
TFE_Op* op = MatMulOp(ctx, m, m);
|
TFE_Op* op = MatMulOp(eager_ctx_, m, m);
|
||||||
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
TFE_OpPrintDebugString(op);
|
TFE_OpPrintDebugString(op);
|
||||||
|
|
||||||
auto* graph = TF_NewGraph();
|
|
||||||
auto* trace_ctx = TFE_NewTraceContext(graph);
|
|
||||||
TFE_TensorHandle* retvals[5];
|
TFE_TensorHandle* retvals[5];
|
||||||
int num_retvals = 5;
|
int num_retvals = 5;
|
||||||
// Symbolically execute this op, which adds a graph node to `trace_ctx`.
|
// Symbolically execute this op, which adds a graph node to `trace_ctx`.
|
||||||
TFE_AddEagerOpToGraph(op, trace_ctx, retvals, &num_retvals, status);
|
TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
|
||||||
int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx);
|
int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx_);
|
||||||
CHECK_EQ(num_inputs, 1);
|
CHECK_EQ(num_inputs, 1);
|
||||||
auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx,
|
auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx_,
|
||||||
/*idx*/ 0);
|
/*idx*/ 0);
|
||||||
|
|
||||||
LOG(INFO) << tensorflow::getTF_OutputDebugString(input_sym_tensor);
|
LOG(INFO) << tensorflow::getTF_OutputDebugString(input_sym_tensor);
|
||||||
auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx,
|
auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx_,
|
||||||
/*idx*/ 0);
|
/*idx*/ 0);
|
||||||
TFE_TensorHandlePrintDebugString(handle);
|
TFE_TensorHandlePrintDebugString(handle);
|
||||||
TFE_DeleteTensorHandle(handle);
|
TFE_DeleteTensorHandle(handle);
|
||||||
|
|
||||||
CHECK_EQ(num_retvals, 1);
|
CHECK_EQ(num_retvals, 1);
|
||||||
CHECK_EQ(TFE_TensorHandleDataType(retvals[0]), TF_FLOAT);
|
CHECK_EQ(TFE_TensorHandleDataType(retvals[0]), TF_FLOAT);
|
||||||
|
|
||||||
TFE_DeleteTensorHandle(retvals[0]);
|
TFE_DeleteTensorHandle(retvals[0]);
|
||||||
|
|
||||||
TFE_DeleteTraceContext(trace_ctx);
|
|
||||||
TF_DeleteGraph(graph);
|
|
||||||
|
|
||||||
TFE_DeleteTensorHandle(m);
|
TFE_DeleteTensorHandle(m);
|
||||||
TFE_DeleteOp(op);
|
TFE_DeleteOp(op);
|
||||||
TFE_DeleteContext(ctx);
|
}
|
||||||
TF_DeleteStatus(status);
|
|
||||||
|
TEST_F(AddEagerOpToGraphTest, ValueAttributesArePreserved) {
|
||||||
|
// Create MinOp
|
||||||
|
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||||
|
TFE_Op* op = MinOp(eager_ctx_, axis, axis);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
|
||||||
|
// Check the attributes set by the call to MinOp above.
|
||||||
|
AddEagerOpToGraphAndCheck(op, [this, &axis](TF_Operation* graph_op) {
|
||||||
|
unsigned char value;
|
||||||
|
TF_OperationGetAttrBool(graph_op, "keep_dims", &value, status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
CHECK_EQ(value, 1);
|
||||||
|
TF_DataType dtype;
|
||||||
|
TF_OperationGetAttrType(graph_op, "Tidx", &dtype, status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
CHECK_EQ(dtype, TF_INT32);
|
||||||
|
TF_OperationGetAttrType(graph_op, "T", &dtype, status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
CHECK_EQ(dtype, TFE_TensorHandleDataType(axis));
|
||||||
|
});
|
||||||
|
TFE_DeleteTensorHandle(axis);
|
||||||
|
TFE_DeleteOp(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) {
|
||||||
|
// Create a "Squeeze" operator with list attributes.
|
||||||
|
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||||
|
TFE_Op* squeeze = TFE_NewOp(eager_ctx_, "Squeeze", status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
TFE_OpAddInput(squeeze, axis, status_);
|
||||||
|
TFE_OpSetAttrType(squeeze, "T", TF_INT32);
|
||||||
|
std::vector<int64_t> boundaries = {1, 2, 3, 4};
|
||||||
|
TFE_OpSetAttrIntList(squeeze, "squeeze_dims", boundaries.data(),
|
||||||
|
boundaries.size());
|
||||||
|
// Check attributes are preserved.
|
||||||
|
AddEagerOpToGraphAndCheck(
|
||||||
|
squeeze, [this, &boundaries](TF_Operation* squeeze_graph_op) {
|
||||||
|
TF_DataType dtype;
|
||||||
|
TF_OperationGetAttrType(squeeze_graph_op, "T", &dtype, status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
CHECK_EQ(dtype, TF_INT32);
|
||||||
|
std::unique_ptr<int64_t[]> list(new int64_t[boundaries.size()]);
|
||||||
|
TF_OperationGetAttrIntList(squeeze_graph_op, "squeeze_dims", list.get(),
|
||||||
|
boundaries.size(), status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
EXPECT_TRUE(std::equal(list.get(), list.get() + boundaries.size(),
|
||||||
|
boundaries.begin()));
|
||||||
|
});
|
||||||
|
TFE_DeleteTensorHandle(axis);
|
||||||
|
TFE_DeleteOp(squeeze);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user