Handle number attributes correctly in AddEagerOpToGraph API.
PiperOrigin-RevId: 242544702
This commit is contained in:
parent
115112f7fb
commit
6c71cb2d20
@ -799,8 +799,8 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
|||||||
const auto& op_type = op->operation.Name();
|
const auto& op_type = op->operation.Name();
|
||||||
auto op_name =
|
auto op_name =
|
||||||
tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
|
tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
|
||||||
auto* desc =
|
std::unique_ptr<TF_OperationDescription> desc(
|
||||||
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
|
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()));
|
||||||
|
|
||||||
VLOG(1) << "Adding attrs.";
|
VLOG(1) << "Adding attrs.";
|
||||||
tensorflow::AttrValueMap attrs;
|
tensorflow::AttrValueMap attrs;
|
||||||
@ -814,30 +814,42 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
|||||||
size_t inputIndex = 0;
|
size_t inputIndex = 0;
|
||||||
const tensorflow::OpDef& op_def = desc->node_builder.op_def();
|
const tensorflow::OpDef& op_def = desc->node_builder.op_def();
|
||||||
for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
|
for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
|
||||||
// TODO(bgogul): Add support for number attributes.
|
if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) {
|
||||||
DCHECK(input_arg.number_attr().empty())
|
|
||||||
<< "Number attributes is not implemented yet.";
|
|
||||||
if (input_arg.type_list_attr().empty()) {
|
|
||||||
auto symbolic_input =
|
auto symbolic_input =
|
||||||
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
||||||
if (!status->status.ok()) return nullptr;
|
if (!status->status.ok()) return nullptr;
|
||||||
TF_AddInput(desc, symbolic_input);
|
TF_AddInput(desc.get(), symbolic_input);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const std::string& type_list_attr = input_arg.type_list_attr();
|
size_t list_size = 0;
|
||||||
const auto& attr_value = attrs[type_list_attr];
|
if (!input_arg.type_list_attr().empty()) {
|
||||||
DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
|
const std::string& type_list_attr = input_arg.type_list_attr();
|
||||||
<< "Type list attribute should be a list!";
|
const auto& attr_value = attrs[type_list_attr];
|
||||||
std::vector<TF_Output> list_inputs(attr_value.list().type_size());
|
CHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
|
||||||
|
<< "Type list attribute should be a list!";
|
||||||
|
list_size = attr_value.list().type_size();
|
||||||
|
} else {
|
||||||
|
CHECK(!input_arg.number_attr().empty());
|
||||||
|
const auto& attr_value = attrs[input_arg.number_attr()];
|
||||||
|
CHECK(attr_value.value_case() == tensorflow::AttrValue::kI)
|
||||||
|
<< "Number attribute should be int!";
|
||||||
|
if (attr_value.i() < 0) {
|
||||||
|
status->status = tensorflow::errors::Internal(
|
||||||
|
"Number attribute for length should be >=0!");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
list_size = attr_value.i();
|
||||||
|
}
|
||||||
|
std::vector<TF_Output> list_inputs(list_size);
|
||||||
for (TF_Output& list_input : list_inputs) {
|
for (TF_Output& list_input : list_inputs) {
|
||||||
list_input =
|
list_input =
|
||||||
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
||||||
if (!status->status.ok()) return nullptr;
|
if (!status->status.ok()) return nullptr;
|
||||||
}
|
}
|
||||||
TF_AddInputList(desc, list_inputs.data(), list_inputs.size());
|
TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* graph_op = TF_FinishOperation(desc, status);
|
auto* graph_op = TF_FinishOperation(desc.release(), status);
|
||||||
if (!status->status.ok()) return nullptr;
|
if (!status->status.ok()) return nullptr;
|
||||||
|
|
||||||
VLOG(1) << "Op finalized; setting return tensors.";
|
VLOG(1) << "Op finalized; setting return tensors.";
|
||||||
|
@ -376,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
|
|||||||
TFE_DeleteOp(identityn);
|
TFE_DeleteOp(identityn);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) {
|
||||||
|
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
|
||||||
|
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||||
|
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
|
||||||
|
TFE_OpSetAttrInt(concatv2, "N", 2);
|
||||||
|
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
|
||||||
|
constexpr size_t kNumInputs = 2;
|
||||||
|
for (size_t i = 0; i < kNumInputs; ++i) {
|
||||||
|
TFE_OpAddInput(concatv2, matrix, status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
}
|
||||||
|
TFE_OpAddInput(concatv2, axis, status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
AddEagerOpToGraphAndCheck(
|
||||||
|
concatv2, [this, kNumInputs](TF_Operation* graph_op) {
|
||||||
|
EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1);
|
||||||
|
int64_t attrN;
|
||||||
|
TF_OperationGetAttrInt(graph_op, "N", &attrN, status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
EXPECT_EQ(attrN, kNumInputs);
|
||||||
|
EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_),
|
||||||
|
kNumInputs);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
});
|
||||||
|
TFE_DeleteTensorHandle(axis);
|
||||||
|
TFE_DeleteTensorHandle(matrix);
|
||||||
|
TFE_DeleteOp(concatv2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AddEagerOpToGraphTest,
|
||||||
|
GeneratesInternalErrorsForInvalidNumberAttributes) {
|
||||||
|
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
|
||||||
|
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||||
|
int num_retvals = 5;
|
||||||
|
TFE_TensorHandle* retvals[5];
|
||||||
|
|
||||||
|
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
|
||||||
|
TFE_OpSetAttrInt(concatv2, "N", -1);
|
||||||
|
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
|
||||||
|
|
||||||
|
TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals,
|
||||||
|
&num_retvals, status_);
|
||||||
|
EXPECT_EQ(graph_op, nullptr);
|
||||||
|
EXPECT_EQ(status_->status.error_message(),
|
||||||
|
"Number attribute for length should be >=0!");
|
||||||
|
|
||||||
|
TFE_DeleteOp(concatv2);
|
||||||
|
TFE_DeleteTensorHandle(axis);
|
||||||
|
TFE_DeleteTensorHandle(matrix);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user