Handle number attributes correctly in AddEagerOpToGraph API.

PiperOrigin-RevId: 242544702
This commit is contained in:
A. Unique TensorFlower 2019-04-08 14:55:37 -07:00 committed by TensorFlower Gardener
parent 115112f7fb
commit 6c71cb2d20
2 changed files with 81 additions and 14 deletions

View File

@ -799,8 +799,8 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
const auto& op_type = op->operation.Name();
auto op_name =
tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
auto* desc =
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
std::unique_ptr<TF_OperationDescription> desc(
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()));
VLOG(1) << "Adding attrs.";
tensorflow::AttrValueMap attrs;
@ -814,30 +814,42 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
size_t inputIndex = 0;
const tensorflow::OpDef& op_def = desc->node_builder.op_def();
for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
// TODO(bgogul): Add support for number attributes.
DCHECK(input_arg.number_attr().empty())
<< "Number attributes is not implemented yet.";
if (input_arg.type_list_attr().empty()) {
if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) {
auto symbolic_input =
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
if (!status->status.ok()) return nullptr;
TF_AddInput(desc, symbolic_input);
TF_AddInput(desc.get(), symbolic_input);
continue;
}
const std::string& type_list_attr = input_arg.type_list_attr();
const auto& attr_value = attrs[type_list_attr];
DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
<< "Type list attribute should be a list!";
std::vector<TF_Output> list_inputs(attr_value.list().type_size());
size_t list_size = 0;
if (!input_arg.type_list_attr().empty()) {
const std::string& type_list_attr = input_arg.type_list_attr();
const auto& attr_value = attrs[type_list_attr];
CHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
<< "Type list attribute should be a list!";
list_size = attr_value.list().type_size();
} else {
CHECK(!input_arg.number_attr().empty());
const auto& attr_value = attrs[input_arg.number_attr()];
CHECK(attr_value.value_case() == tensorflow::AttrValue::kI)
<< "Number attribute should be int!";
if (attr_value.i() < 0) {
status->status = tensorflow::errors::Internal(
"Number attribute for length should be >=0!");
return nullptr;
}
list_size = attr_value.i();
}
std::vector<TF_Output> list_inputs(list_size);
for (TF_Output& list_input : list_inputs) {
list_input =
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
if (!status->status.ok()) return nullptr;
}
TF_AddInputList(desc, list_inputs.data(), list_inputs.size());
TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size());
}
auto* graph_op = TF_FinishOperation(desc, status);
auto* graph_op = TF_FinishOperation(desc.release(), status);
if (!status->status.ok()) return nullptr;
VLOG(1) << "Op finalized; setting return tensors.";

View File

@ -376,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
TFE_DeleteOp(identityn);
}
TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) {
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
TFE_OpSetAttrInt(concatv2, "N", 2);
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
constexpr size_t kNumInputs = 2;
for (size_t i = 0; i < kNumInputs; ++i) {
TFE_OpAddInput(concatv2, matrix, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
}
TFE_OpAddInput(concatv2, axis, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
AddEagerOpToGraphAndCheck(
concatv2, [this, kNumInputs](TF_Operation* graph_op) {
EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1);
int64_t attrN;
TF_OperationGetAttrInt(graph_op, "N", &attrN, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
EXPECT_EQ(attrN, kNumInputs);
EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_),
kNumInputs);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
});
TFE_DeleteTensorHandle(axis);
TFE_DeleteTensorHandle(matrix);
TFE_DeleteOp(concatv2);
}
TEST_F(AddEagerOpToGraphTest,
GeneratesInternalErrorsForInvalidNumberAttributes) {
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
int num_retvals = 5;
TFE_TensorHandle* retvals[5];
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
TFE_OpSetAttrInt(concatv2, "N", -1);
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals,
&num_retvals, status_);
EXPECT_EQ(graph_op, nullptr);
EXPECT_EQ(status_->status.error_message(),
"Number attribute for length should be >=0!");
TFE_DeleteOp(concatv2);
TFE_DeleteTensorHandle(axis);
TFE_DeleteTensorHandle(matrix);
}
} // namespace
} // namespace tensorflow