diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 41bc34a0e9c..6da2a02a4fb 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -799,8 +799,8 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, const auto& op_type = op->operation.Name(); auto op_name = tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++); - auto* desc = - TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()); + std::unique_ptr desc( + TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str())); VLOG(1) << "Adding attrs."; tensorflow::AttrValueMap attrs; @@ -814,30 +814,42 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, size_t inputIndex = 0; const tensorflow::OpDef& op_def = desc->node_builder.op_def(); for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) { - // TODO(bgogul): Add support for number attributes. - DCHECK(input_arg.number_attr().empty()) - << "Number attributes is not implemented yet."; - if (input_arg.type_list_attr().empty()) { + if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) { auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); if (!status->status.ok()) return nullptr; - TF_AddInput(desc, symbolic_input); + TF_AddInput(desc.get(), symbolic_input); continue; } - const std::string& type_list_attr = input_arg.type_list_attr(); - const auto& attr_value = attrs[type_list_attr]; - DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList) - << "Type list attribute should be a list!"; - std::vector list_inputs(attr_value.list().type_size()); + size_t list_size = 0; + if (!input_arg.type_list_attr().empty()) { + const std::string& type_list_attr = input_arg.type_list_attr(); + const auto& attr_value = attrs[type_list_attr]; + CHECK(attr_value.value_case() == tensorflow::AttrValue::kList) + << "Type list attribute should be a list!"; + list_size = attr_value.list().type_size(); + } else { + CHECK(!input_arg.number_attr().empty()); + const auto& attr_value = attrs[input_arg.number_attr()]; + CHECK(attr_value.value_case() == tensorflow::AttrValue::kI) + << "Number attribute should be int!"; + if (attr_value.i() < 0) { + status->status = tensorflow::errors::Internal( + "Number attribute for length should be >=0!"); + return nullptr; + } + list_size = attr_value.i(); + } + std::vector list_inputs(list_size); for (TF_Output& list_input : list_inputs) { list_input = getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); if (!status->status.ok()) return nullptr; } - TF_AddInputList(desc, list_inputs.data(), list_inputs.size()); + TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size()); } - auto* graph_op = TF_FinishOperation(desc, status); + auto* graph_op = TF_FinishOperation(desc.release(), status); if (!status->status.ok()) return nullptr; VLOG(1) << "Op finalized; setting return tensors."; diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index c2b8f3f7631..6eb289107c5 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -376,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) { TFE_DeleteOp(identityn); } +TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) { + TFE_TensorHandle* matrix = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpSetAttrType(concatv2, "T", TF_FLOAT); + TFE_OpSetAttrInt(concatv2, "N", 2); + TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32); + constexpr size_t kNumInputs = 2; + for (size_t i = 0; i < kNumInputs; ++i) { + TFE_OpAddInput(concatv2, matrix, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + } + TFE_OpAddInput(concatv2, axis, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + AddEagerOpToGraphAndCheck( + concatv2, [this, kNumInputs](TF_Operation* graph_op) { + EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1); + int64_t attrN; + TF_OperationGetAttrInt(graph_op, "N", &attrN, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_EQ(attrN, kNumInputs); + EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_), + kNumInputs); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteTensorHandle(matrix); + TFE_DeleteOp(concatv2); +} + +TEST_F(AddEagerOpToGraphTest, + GeneratesInternalErrorsForInvalidNumberAttributes) { + TFE_TensorHandle* matrix = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + int num_retvals = 5; + TFE_TensorHandle* retvals[5]; + + TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpSetAttrType(concatv2, "T", TF_FLOAT); + TFE_OpSetAttrInt(concatv2, "N", -1); + TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32); + + TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals, + &num_retvals, status_); + EXPECT_EQ(graph_op, nullptr); + EXPECT_EQ(status_->status.error_message(), + "Number attribute for length should be >=0!"); + + TFE_DeleteOp(concatv2); + TFE_DeleteTensorHandle(axis); + TFE_DeleteTensorHandle(matrix); +} + } // namespace } // namespace tensorflow