Handle list inputs correctly in AddEagerOpToGraph API.
PiperOrigin-RevId: 234166416
This commit is contained in:
parent
12c5e6c4ce
commit
896ad1053b
@ -9064,11 +9064,6 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
|
||||
auto* desc =
|
||||
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
|
||||
for (auto* input : op->operation.Inputs()) {
|
||||
auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TF_AddInput(desc, symbolic_input);
|
||||
}
|
||||
|
||||
VLOG(1) << "Adding attrs.";
|
||||
tensorflow::AttrValueMap attrs;
|
||||
@ -9077,6 +9072,34 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
desc->node_builder.Attr(attr.first, attr.second);
|
||||
}
|
||||
|
||||
VLOG(1) << "Adding inputs.";
|
||||
const auto& inputs = op->operation.Inputs();
|
||||
size_t inputIndex = 0;
|
||||
const tensorflow::OpDef& op_def = desc->node_builder.op_def();
|
||||
for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
|
||||
// TODO(bgogul): Add support for number attributes.
|
||||
DCHECK(input_arg.number_attr().empty())
|
||||
<< "Number attributes is not implemented yet.";
|
||||
if (input_arg.type_list_attr().empty()) {
|
||||
auto symbolic_input =
|
||||
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TF_AddInput(desc, symbolic_input);
|
||||
continue;
|
||||
}
|
||||
const std::string& type_list_attr = input_arg.type_list_attr();
|
||||
const auto& attr_value = attrs[type_list_attr];
|
||||
DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
|
||||
<< "Type list attribute should be a list!";
|
||||
std::vector<TF_Output> list_inputs(attr_value.list().type_size());
|
||||
for (TF_Output& list_input : list_inputs) {
|
||||
list_input =
|
||||
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
}
|
||||
TF_AddInputList(desc, list_inputs.data(), list_inputs.size());
|
||||
}
|
||||
|
||||
auto* graph_op = TF_FinishOperation(desc, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
|
@ -446,5 +446,29 @@ TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) {
|
||||
TFE_DeleteOp(squeeze);
|
||||
}
|
||||
|
||||
TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
|
||||
TFE_TensorHandle* scalar = TestScalarTensorHandle();
|
||||
TFE_Op* identityn = TFE_NewOp(eager_ctx_, "IdentityN", status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
constexpr size_t kNumInputs = 3;
|
||||
for (size_t i = 0; i < kNumInputs; ++i) {
|
||||
TFE_OpAddInput(identityn, scalar, status_);
|
||||
}
|
||||
TF_DataType types[kNumInputs] = {TF_FLOAT, TF_FLOAT, TF_FLOAT};
|
||||
TFE_OpSetAttrTypeList(identityn, "T", types, kNumInputs);
|
||||
AddEagerOpToGraphAndCheck(
|
||||
identityn, [this, kNumInputs](TF_Operation* graph_op) {
|
||||
EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs);
|
||||
EXPECT_EQ(TF_OperationInputListLength(graph_op, "input", status_),
|
||||
kNumInputs);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
EXPECT_EQ(TF_OperationOutputListLength(graph_op, "output", status_),
|
||||
kNumInputs);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
});
|
||||
TFE_DeleteTensorHandle(scalar);
|
||||
TFE_DeleteOp(identityn);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user