Handle number attributes correctly in AddEagerOpToGraph API.
PiperOrigin-RevId: 242544702
This commit is contained in:
parent
115112f7fb
commit
6c71cb2d20
@ -799,8 +799,8 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
const auto& op_type = op->operation.Name();
|
||||
auto op_name =
|
||||
tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
|
||||
auto* desc =
|
||||
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
|
||||
std::unique_ptr<TF_OperationDescription> desc(
|
||||
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()));
|
||||
|
||||
VLOG(1) << "Adding attrs.";
|
||||
tensorflow::AttrValueMap attrs;
|
||||
@ -814,30 +814,42 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
size_t inputIndex = 0;
|
||||
const tensorflow::OpDef& op_def = desc->node_builder.op_def();
|
||||
for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
|
||||
// TODO(bgogul): Add support for number attributes.
|
||||
DCHECK(input_arg.number_attr().empty())
|
||||
<< "Number attributes is not implemented yet.";
|
||||
if (input_arg.type_list_attr().empty()) {
|
||||
if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) {
|
||||
auto symbolic_input =
|
||||
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TF_AddInput(desc, symbolic_input);
|
||||
TF_AddInput(desc.get(), symbolic_input);
|
||||
continue;
|
||||
}
|
||||
size_t list_size = 0;
|
||||
if (!input_arg.type_list_attr().empty()) {
|
||||
const std::string& type_list_attr = input_arg.type_list_attr();
|
||||
const auto& attr_value = attrs[type_list_attr];
|
||||
DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
|
||||
CHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
|
||||
<< "Type list attribute should be a list!";
|
||||
std::vector<TF_Output> list_inputs(attr_value.list().type_size());
|
||||
list_size = attr_value.list().type_size();
|
||||
} else {
|
||||
CHECK(!input_arg.number_attr().empty());
|
||||
const auto& attr_value = attrs[input_arg.number_attr()];
|
||||
CHECK(attr_value.value_case() == tensorflow::AttrValue::kI)
|
||||
<< "Number attribute should be int!";
|
||||
if (attr_value.i() < 0) {
|
||||
status->status = tensorflow::errors::Internal(
|
||||
"Number attribute for length should be >=0!");
|
||||
return nullptr;
|
||||
}
|
||||
list_size = attr_value.i();
|
||||
}
|
||||
std::vector<TF_Output> list_inputs(list_size);
|
||||
for (TF_Output& list_input : list_inputs) {
|
||||
list_input =
|
||||
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
}
|
||||
TF_AddInputList(desc, list_inputs.data(), list_inputs.size());
|
||||
TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size());
|
||||
}
|
||||
|
||||
auto* graph_op = TF_FinishOperation(desc, status);
|
||||
auto* graph_op = TF_FinishOperation(desc.release(), status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
VLOG(1) << "Op finalized; setting return tensors.";
|
||||
|
@ -376,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
|
||||
TFE_DeleteOp(identityn);
|
||||
}
|
||||
|
||||
TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) {
|
||||
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
|
||||
TFE_OpSetAttrInt(concatv2, "N", 2);
|
||||
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
|
||||
constexpr size_t kNumInputs = 2;
|
||||
for (size_t i = 0; i < kNumInputs; ++i) {
|
||||
TFE_OpAddInput(concatv2, matrix, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
}
|
||||
TFE_OpAddInput(concatv2, axis, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
AddEagerOpToGraphAndCheck(
|
||||
concatv2, [this, kNumInputs](TF_Operation* graph_op) {
|
||||
EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1);
|
||||
int64_t attrN;
|
||||
TF_OperationGetAttrInt(graph_op, "N", &attrN, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
EXPECT_EQ(attrN, kNumInputs);
|
||||
EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_),
|
||||
kNumInputs);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
});
|
||||
TFE_DeleteTensorHandle(axis);
|
||||
TFE_DeleteTensorHandle(matrix);
|
||||
TFE_DeleteOp(concatv2);
|
||||
}
|
||||
|
||||
TEST_F(AddEagerOpToGraphTest,
|
||||
GeneratesInternalErrorsForInvalidNumberAttributes) {
|
||||
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
int num_retvals = 5;
|
||||
TFE_TensorHandle* retvals[5];
|
||||
|
||||
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
|
||||
TFE_OpSetAttrInt(concatv2, "N", -1);
|
||||
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
|
||||
|
||||
TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals,
|
||||
&num_retvals, status_);
|
||||
EXPECT_EQ(graph_op, nullptr);
|
||||
EXPECT_EQ(status_->status.error_message(),
|
||||
"Number attribute for length should be >=0!");
|
||||
|
||||
TFE_DeleteOp(concatv2);
|
||||
TFE_DeleteTensorHandle(axis);
|
||||
TFE_DeleteTensorHandle(matrix);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user