diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc
index bbd70151849..e825aa722b5 100644
--- a/tensorflow/core/framework/graph_to_functiondef.cc
+++ b/tensorflow/core/framework/graph_to_functiondef.cc
@@ -434,9 +434,11 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
       // _Arg/Placeholder nodes.
       if (absl::StartsWith(attr.first, "_")) {
         arg_attrs.mutable_attr()->insert(attr);
-      } else if (attr.first == "shape") {
+      } else if (attr.first == "shape" && argdef->type() != DT_RESOURCE) {
         // Preserve known shapes by moving them to the _output_shapes list.
         // The _Arg shape function knows how to extract them from there.
+        // Don't preserve the shape of a resource arg node, which is a scalar
+        // resource handle.
         AttrValue value;
         *(value.mutable_list()->add_shape()) = attr.second.shape();
         arg_attrs.mutable_attr()->insert({"_output_shapes", value});
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 780e3c7e3f2..a83fb824cc3 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -38,12 +38,14 @@ namespace grappler {
 
 GrapplerFunctionItem::GrapplerFunctionItem(
     string func_name, string description, AttrSlice func_attr,
+    std::vector<const FunctionDef::ArgAttrs*> arg_attr,
     std::vector<InputArgInstantiation> input_args,
     std::vector<OutputArgInstantiation> output_args,
     std::vector<ControlOutput> control_outputs, const int graph_def_version,
     const bool is_stateful, GraphDef&& function_body)
     : description_(std::move(description)),
       func_attr_(func_attr),
+      arg_attr_(std::move(arg_attr)),
       input_args_(std::move(input_args)),
       output_args_(std::move(output_args)),
       control_outputs_(std::move(control_outputs)),
@@ -108,6 +110,11 @@ const std::size_t GrapplerFunctionItem::control_output_size() const {
 
 const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
 
+const std::vector<const FunctionDef::ArgAttrs*>&
+GrapplerFunctionItem::arg_attr() const {
+  return arg_attr_;
+}
+
 const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
 
 GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
@@ -278,12 +285,17 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
     control_outputs.push_back({control_ret.first, control_ret.second});
   }
 
+  std::vector<const FunctionDef::ArgAttrs*> arg_attr(inputs.size(), nullptr);
+  for (const auto& attr : func.arg_attr()) {
+    arg_attr.at(attr.first) = &attr.second;
+  }
+
   *item = GrapplerFunctionItem(
       /*func_name=*/signature.name(),
       /*description=*/signature.description(),
-      /*func_attr=*/AttrSlice(&func.attr()), std::move(inputs),
-      std::move(outputs), std::move(control_outputs), graph_def_version,
-      signature.is_stateful(), std::move(function_body));
+      /*func_attr=*/AttrSlice(&func.attr()), std::move(arg_attr),
+      std::move(inputs), std::move(outputs), std::move(control_outputs),
+      graph_def_version, signature.is_stateful(), std::move(function_body));
   return Status::OK();
 }
 
@@ -330,6 +342,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
   }
 
   item->input_args_.erase(item->input_args_.begin() + input_index);
+  item->arg_attr_.erase(item->arg_attr_.begin() + input_index);
 
   return Status::OK();
 }
@@ -566,6 +579,14 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item,
     (*func->mutable_attr())[attr_name] = attr_value;
   }
 
+  // Copy function arg attributes.
+  for (int i = 0; i < item.arg_attr().size(); ++i) {
+    const auto* attr = item.arg_attr().at(i);
+    if (attr != nullptr) {
+      (*func->mutable_arg_attr())[i] = *attr;
+    }
+  }
+
   // Copy function body nodes to the FunctionDef and update input format
   for (const NodeDef& func_node : item.function_body().node()) {
     // Skip original `_Arg` and `_Retval` nodes. If node was converted to some
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index b03b89af2ab..2f1fd5d2ed6 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -76,6 +76,7 @@ class GrapplerFunctionItem : public GrapplerItem {
   const std::size_t control_output_size() const;
 
   const AttrSlice& func_attr() const;
+  const std::vector<const FunctionDef::ArgAttrs*>& arg_attr() const;
   const GraphDef& function_body() const;
   GraphDef& mutable_function_body();
 
@@ -95,6 +96,7 @@ class GrapplerFunctionItem : public GrapplerItem {
 
   GrapplerFunctionItem(string func_name, string description,
                        AttrSlice func_attr,
+                       std::vector<const FunctionDef::ArgAttrs*> arg_attr,
                        std::vector<InputArgInstantiation> input_args,
                        std::vector<OutputArgInstantiation> output_args,
                        std::vector<ControlOutput> control_outputs,
@@ -105,6 +107,9 @@ class GrapplerFunctionItem : public GrapplerItem {
   AttrSlice func_attr_;  // Attributes specific to function definition that
                          // produced this item (FuncDef.attr field).
 
+  // Attributes of function arguments
+  std::vector<const FunctionDef::ArgAttrs*> arg_attr_;
+
   std::vector<InputArgInstantiation> input_args_;
   std::vector<OutputArgInstantiation> output_args_;
   std::vector<ControlOutput> control_outputs_;
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index 8cc938ec845..66320d60f27 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -523,6 +523,14 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
           {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
       });
 
+  // Add an attribute to _Arg 0;
+  const uint32 arg_index = 0;
+  const std::pair<string, string> arg_attr_key_and_value = {"_arg_attr", "abc"};
+  FunctionDef::ArgAttrs arg_attr;
+  (*arg_attr.mutable_attr())[arg_attr_key_and_value.first].set_s(
+      arg_attr_key_and_value.second);
+  (*func.mutable_arg_attr())[arg_index] = arg_attr;
+
   protobuf::Map<string, AttrValue> func_instantiation_attr;
   func_instantiation_attr["T"].set_type(DT_FLOAT);
   FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
@@ -541,6 +549,15 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
   EXPECT_EQ("y", specialized.signature().output_arg(0).name());
   EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type());
 
+  EXPECT_EQ(specialized.arg_attr().size(), 1);
+  EXPECT_EQ(specialized.arg_attr().at(arg_index).attr().size(), 1);
+  EXPECT_EQ(specialized.arg_attr()
+                .at(arg_index)
+                .attr()
+                .at(arg_attr_key_and_value.first)
+                .s(),
+            arg_attr_key_and_value.second);
+
   // Function body specialized for instantiation types.
   int count = 0;
   for (const NodeDef &node : specialized.node_def()) {