Add support for attributes in TFE_AddEagerOpToGraph API.
PiperOrigin-RevId: 231629067
This commit is contained in:
parent
34b64c3af3
commit
95f8fd4f30
@ -9053,9 +9053,9 @@ static TF_Output getOrCreateSymbolicTensor(TFE_TraceContext* trace_ctx,
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
TFE_TensorHandle** retvals,
|
||||
int* num_retvals, TF_Status* status) {
|
||||
VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": "
|
||||
<< op->operation.DebugString();
|
||||
|
||||
@ -9066,15 +9066,19 @@ void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
|
||||
for (auto* input : op->operation.Inputs()) {
|
||||
auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status);
|
||||
if (!status->status.ok()) return;
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TF_AddInput(desc, symbolic_input);
|
||||
}
|
||||
|
||||
VLOG(1) << "Adding attrs.";
|
||||
// TODO(hongm): add attrs
|
||||
tensorflow::AttrValueMap attrs;
|
||||
op->operation.Attrs().FillAttrValueMap(&attrs);
|
||||
for (const auto& attr : attrs) {
|
||||
desc->node_builder.Attr(attr.first, attr.second);
|
||||
}
|
||||
|
||||
auto* graph_op = TF_FinishOperation(desc, status);
|
||||
if (!status->status.ok()) return;
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
VLOG(1) << "Op finalized; setting return tensors.";
|
||||
*num_retvals = TF_OperationNumOutputs(graph_op);
|
||||
@ -9084,6 +9088,7 @@ void TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
auto dtype = TF_OperationOutputType(output);
|
||||
retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype);
|
||||
}
|
||||
return graph_op;
|
||||
}
|
||||
|
||||
int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) {
|
||||
|
@ -294,12 +294,11 @@ TF_CAPI_EXPORT extern void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx);
|
||||
|
||||
// Symbolically executes `op`, by adding a corresponding node to the graph
|
||||
// associated with `trace_ctx`. This graph node outputs a set of symbolic
|
||||
// tensors in `retvals` and `num_retvals`.
|
||||
TF_CAPI_EXPORT extern void TFE_AddEagerOpToGraph(TFE_Op* op,
|
||||
TFE_TraceContext* trace_ctx,
|
||||
TFE_TensorHandle** retvals,
|
||||
int* num_retvals,
|
||||
TF_Status* status);
|
||||
// tensors in `retvals` and `num_retvals`. Returns the corresponding graph
|
||||
// operation on success, otherwise returns nullptr.
|
||||
TF_CAPI_EXPORT extern TF_Operation* TFE_AddEagerOpToGraph(
|
||||
TFE_Op* op, TFE_TraceContext* trace_ctx, TFE_TensorHandle** retvals,
|
||||
int* num_retvals, TF_Status* status);
|
||||
|
||||
// Finalizes the trace graph and its inputs, and returns the number of inputs.
|
||||
// After this call, the next two APIs can be called to iterate over the input
|
||||
|
@ -319,50 +319,131 @@ TEST(CAPI_EXPERIMENTAL, SymbolicTensor) {
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CAPI_EXPERIMENTAL, DebugPrintAndSymbolicExecution) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
class AddEagerOpToGraphTest : public ::testing::Test {
|
||||
protected:
|
||||
AddEagerOpToGraphTest()
|
||||
: status_(TF_NewStatus()),
|
||||
eager_ctx_(nullptr),
|
||||
graph_(TF_NewGraph()),
|
||||
trace_ctx_(TFE_NewTraceContext(graph_)) {
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
eager_ctx_ = TFE_NewContext(opts, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
}
|
||||
|
||||
~AddEagerOpToGraphTest() override {
|
||||
TFE_DeleteTraceContext(trace_ctx_);
|
||||
TF_DeleteGraph(graph_);
|
||||
TFE_DeleteContext(eager_ctx_);
|
||||
TF_DeleteStatus(status_);
|
||||
}
|
||||
|
||||
template <typename Callable>
|
||||
void AddEagerOpToGraphAndCheck(TFE_Op* op, Callable checker) {
|
||||
TFE_TensorHandle* retvals[5];
|
||||
int num_retvals = 5;
|
||||
// Symbolically execute this op, which adds a graph node to `trace_ctx_`.
|
||||
TF_Operation* graph_op =
|
||||
TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CHECK_NOTNULL(graph_op);
|
||||
// Check the expectations.
|
||||
checker(graph_op);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
TFE_DeleteTensorHandle(retvals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TF_Status* status_;
|
||||
TFE_Context* eager_ctx_;
|
||||
TF_Graph* graph_;
|
||||
TFE_TraceContext* trace_ctx_;
|
||||
};
|
||||
|
||||
TEST_F(AddEagerOpToGraphTest, DebugPrintAndSymbolicExecution) {
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_Op* op = MatMulOp(ctx, m, m);
|
||||
TFE_Op* op = MatMulOp(eager_ctx_, m, m);
|
||||
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TFE_OpPrintDebugString(op);
|
||||
|
||||
auto* graph = TF_NewGraph();
|
||||
auto* trace_ctx = TFE_NewTraceContext(graph);
|
||||
TFE_TensorHandle* retvals[5];
|
||||
int num_retvals = 5;
|
||||
// Symbolically execute this op, which adds a graph node to `trace_ctx`.
|
||||
TFE_AddEagerOpToGraph(op, trace_ctx, retvals, &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
|
||||
int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx);
|
||||
int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx_);
|
||||
CHECK_EQ(num_inputs, 1);
|
||||
auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx,
|
||||
auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx_,
|
||||
/*idx*/ 0);
|
||||
|
||||
LOG(INFO) << tensorflow::getTF_OutputDebugString(input_sym_tensor);
|
||||
auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx,
|
||||
auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx_,
|
||||
/*idx*/ 0);
|
||||
TFE_TensorHandlePrintDebugString(handle);
|
||||
TFE_DeleteTensorHandle(handle);
|
||||
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
CHECK_EQ(TFE_TensorHandleDataType(retvals[0]), TF_FLOAT);
|
||||
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteTraceContext(trace_ctx);
|
||||
TF_DeleteGraph(graph);
|
||||
|
||||
TFE_DeleteTensorHandle(m);
|
||||
TFE_DeleteOp(op);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST_F(AddEagerOpToGraphTest, ValueAttributesArePreserved) {
|
||||
// Create MinOp
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_Op* op = MinOp(eager_ctx_, axis, axis);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
|
||||
// Check the attributes set by the call to MinOp above.
|
||||
AddEagerOpToGraphAndCheck(op, [this, &axis](TF_Operation* graph_op) {
|
||||
unsigned char value;
|
||||
TF_OperationGetAttrBool(graph_op, "keep_dims", &value, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CHECK_EQ(value, 1);
|
||||
TF_DataType dtype;
|
||||
TF_OperationGetAttrType(graph_op, "Tidx", &dtype, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CHECK_EQ(dtype, TF_INT32);
|
||||
TF_OperationGetAttrType(graph_op, "T", &dtype, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CHECK_EQ(dtype, TFE_TensorHandleDataType(axis));
|
||||
});
|
||||
TFE_DeleteTensorHandle(axis);
|
||||
TFE_DeleteOp(op);
|
||||
}
|
||||
|
||||
TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) {
|
||||
// Create a "Squeeze" operator with list attributes.
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_Op* squeeze = TFE_NewOp(eager_ctx_, "Squeeze", status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TFE_OpAddInput(squeeze, axis, status_);
|
||||
TFE_OpSetAttrType(squeeze, "T", TF_INT32);
|
||||
std::vector<int64_t> boundaries = {1, 2, 3, 4};
|
||||
TFE_OpSetAttrIntList(squeeze, "squeeze_dims", boundaries.data(),
|
||||
boundaries.size());
|
||||
// Check attributes are preserved.
|
||||
AddEagerOpToGraphAndCheck(
|
||||
squeeze, [this, &boundaries](TF_Operation* squeeze_graph_op) {
|
||||
TF_DataType dtype;
|
||||
TF_OperationGetAttrType(squeeze_graph_op, "T", &dtype, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CHECK_EQ(dtype, TF_INT32);
|
||||
std::unique_ptr<int64_t[]> list(new int64_t[boundaries.size()]);
|
||||
TF_OperationGetAttrIntList(squeeze_graph_op, "squeeze_dims", list.get(),
|
||||
boundaries.size(), status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
EXPECT_TRUE(std::equal(list.get(), list.get() + boundaries.size(),
|
||||
boundaries.begin()));
|
||||
});
|
||||
TFE_DeleteTensorHandle(axis);
|
||||
TFE_DeleteOp(squeeze);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user