diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc index bbd70151849..e825aa722b5 100644 --- a/tensorflow/core/framework/graph_to_functiondef.cc +++ b/tensorflow/core/framework/graph_to_functiondef.cc @@ -434,9 +434,11 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, // _Arg/Placeholder nodes. if (absl::StartsWith(attr.first, "_")) { arg_attrs.mutable_attr()->insert(attr); - } else if (attr.first == "shape") { + } else if (attr.first == "shape" && argdef->type() != DT_RESOURCE) { // Preserve known shapes by moving them to the _output_shapes list. // The _Arg shape function knows how to extract them from there. + // Don't preserve the shape of a resource arg node, which is a scalar + // resource handle. AttrValue value; *(value.mutable_list()->add_shape()) = attr.second.shape(); arg_attrs.mutable_attr()->insert({"_output_shapes", value}); diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index 780e3c7e3f2..a83fb824cc3 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -38,12 +38,14 @@ namespace grappler { GrapplerFunctionItem::GrapplerFunctionItem( string func_name, string description, AttrSlice func_attr, + std::vector<const FunctionDef::ArgAttrs*> arg_attr, std::vector<InputArgInstantiation> input_args, std::vector<OutputArgInstantiation> output_args, std::vector<ControlOutput> control_outputs, const int graph_def_version, const bool is_stateful, GraphDef&& function_body) : description_(std::move(description)), func_attr_(func_attr), + arg_attr_(std::move(arg_attr)), input_args_(std::move(input_args)), output_args_(std::move(output_args)), control_outputs_(std::move(control_outputs)), @@ -108,6 +110,11 @@ const std::size_t GrapplerFunctionItem::control_output_size() const { const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; } +const std::vector<const FunctionDef::ArgAttrs*>& +GrapplerFunctionItem::arg_attr() const { + return arg_attr_; +} + const GraphDef& GrapplerFunctionItem::function_body() const { return graph; } GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; } @@ -278,12 +285,17 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, control_outputs.push_back({control_ret.first, control_ret.second}); } + std::vector<const FunctionDef::ArgAttrs*> arg_attr(inputs.size(), nullptr); + for (const auto& attr : func.arg_attr()) { + arg_attr.at(attr.first) = &attr.second; + } + *item = GrapplerFunctionItem( /*func_name=*/signature.name(), /*description=*/signature.description(), - /*func_attr=*/AttrSlice(&func.attr()), std::move(inputs), - std::move(outputs), std::move(control_outputs), graph_def_version, - signature.is_stateful(), std::move(function_body)); + /*func_attr=*/AttrSlice(&func.attr()), std::move(arg_attr), + std::move(inputs), std::move(outputs), std::move(control_outputs), + graph_def_version, signature.is_stateful(), std::move(function_body)); return Status::OK(); } @@ -330,6 +342,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, } item->input_args_.erase(item->input_args_.begin() + input_index); + item->arg_attr_.erase(item->arg_attr_.begin() + input_index); return Status::OK(); } @@ -566,6 +579,14 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item, (*func->mutable_attr())[attr_name] = attr_value; } + // Copy function arg attributes. + for (int i = 0; i < item.arg_attr().size(); ++i) { + const auto* attr = item.arg_attr().at(i); + if (attr != nullptr) { + (*func->mutable_arg_attr())[i] = *attr; + } + } + // Copy function body nodes to the FunctionDef and update input format for (const NodeDef& func_node : item.function_body().node()) { // Skip original `_Arg` and `_Retval` nodes. If node was converted to some diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index b03b89af2ab..2f1fd5d2ed6 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -76,6 +76,7 @@ class GrapplerFunctionItem : public GrapplerItem { const std::size_t control_output_size() const; const AttrSlice& func_attr() const; + const std::vector<const FunctionDef::ArgAttrs*>& arg_attr() const; const GraphDef& function_body() const; GraphDef& mutable_function_body(); @@ -95,6 +96,7 @@ class GrapplerFunctionItem : public GrapplerItem { GrapplerFunctionItem(string func_name, string description, AttrSlice func_attr, + std::vector<const FunctionDef::ArgAttrs*> arg_attr, std::vector<InputArgInstantiation> input_args, std::vector<OutputArgInstantiation> output_args, std::vector<ControlOutput> control_outputs, @@ -105,6 +107,9 @@ class GrapplerFunctionItem : public GrapplerItem { AttrSlice func_attr_; // Attributes specific to function definition that // produced this item (FuncDef.attr field). + // Attributes of function arguments + std::vector<const FunctionDef::ArgAttrs*> arg_attr_; + std::vector<InputArgInstantiation> input_args_; std::vector<OutputArgInstantiation> output_args_; std::vector<ControlOutput> control_outputs_; diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc index 8cc938ec845..66320d60f27 100644 --- a/tensorflow/core/grappler/utils/functions_test.cc +++ b/tensorflow/core/grappler/utils/functions_test.cc @@ -523,6 +523,14 @@ TEST_F(FunctionsTest, MakeFunctionDef) { {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}}, }); + // Add an attribute to _Arg 0; + const uint32 arg_index = 0; + const std::pair<string, string> arg_attr_key_and_value = {"_arg_attr", "abc"}; + FunctionDef::ArgAttrs arg_attr; + (*arg_attr.mutable_attr())[arg_attr_key_and_value.first].set_s( + arg_attr_key_and_value.second); + (*func.mutable_arg_attr())[arg_index] = arg_attr; + protobuf::Map<string, AttrValue> func_instantiation_attr; func_instantiation_attr["T"].set_type(DT_FLOAT); FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); @@ -541,6 +549,15 @@ TEST_F(FunctionsTest, MakeFunctionDef) { EXPECT_EQ("y", specialized.signature().output_arg(0).name()); EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type()); + EXPECT_EQ(specialized.arg_attr().size(), 1); + EXPECT_EQ(specialized.arg_attr().at(arg_index).attr().size(), 1); + EXPECT_EQ(specialized.arg_attr() + .at(arg_index) + .attr() + .at(arg_attr_key_and_value.first) + .s(), + arg_attr_key_and_value.second); + // Function body specialized for instantiation types. int count = 0; for (const NodeDef &node : specialized.node_def()) {