Handle list inputs correctly in AddEagerOpToGraph API.

PiperOrigin-RevId: 234166416
This commit is contained in:
A. Unique TensorFlower 2019-02-15 10:16:46 -08:00 committed by TensorFlower Gardener
parent 12c5e6c4ce
commit 896ad1053b
2 changed files with 52 additions and 5 deletions

View File

@ -9064,11 +9064,6 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
auto* desc =
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
for (auto* input : op->operation.Inputs()) {
auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status);
if (!status->status.ok()) return nullptr;
TF_AddInput(desc, symbolic_input);
}
VLOG(1) << "Adding attrs.";
tensorflow::AttrValueMap attrs;
@ -9077,6 +9072,34 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
desc->node_builder.Attr(attr.first, attr.second);
}
VLOG(1) << "Adding inputs.";
const auto& inputs = op->operation.Inputs();
size_t inputIndex = 0;
const tensorflow::OpDef& op_def = desc->node_builder.op_def();
for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
// TODO(bgogul): Add support for number attributes.
DCHECK(input_arg.number_attr().empty())
<< "Number attributes is not implemented yet.";
if (input_arg.type_list_attr().empty()) {
auto symbolic_input =
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
if (!status->status.ok()) return nullptr;
TF_AddInput(desc, symbolic_input);
continue;
}
const std::string& type_list_attr = input_arg.type_list_attr();
const auto& attr_value = attrs[type_list_attr];
DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
<< "Type list attribute should be a list!";
std::vector<TF_Output> list_inputs(attr_value.list().type_size());
for (TF_Output& list_input : list_inputs) {
list_input =
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
if (!status->status.ok()) return nullptr;
}
TF_AddInputList(desc, list_inputs.data(), list_inputs.size());
}
auto* graph_op = TF_FinishOperation(desc, status);
if (!status->status.ok()) return nullptr;

View File

@ -446,5 +446,29 @@ TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) {
TFE_DeleteOp(squeeze);
}
TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
TFE_TensorHandle* scalar = TestScalarTensorHandle();
TFE_Op* identityn = TFE_NewOp(eager_ctx_, "IdentityN", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
constexpr size_t kNumInputs = 3;
for (size_t i = 0; i < kNumInputs; ++i) {
TFE_OpAddInput(identityn, scalar, status_);
}
TF_DataType types[kNumInputs] = {TF_FLOAT, TF_FLOAT, TF_FLOAT};
TFE_OpSetAttrTypeList(identityn, "T", types, kNumInputs);
AddEagerOpToGraphAndCheck(
identityn, [this, kNumInputs](TF_Operation* graph_op) {
EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs);
EXPECT_EQ(TF_OperationInputListLength(graph_op, "input", status_),
kNumInputs);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
EXPECT_EQ(TF_OperationOutputListLength(graph_op, "output", status_),
kNumInputs);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
});
TFE_DeleteTensorHandle(scalar);
TFE_DeleteOp(identityn);
}
} // namespace
} // namespace tensorflow