From 906fac31889c94c3ba310191a4f4910d05eb1b9b Mon Sep 17 00:00:00 2001 From: Bruce Fontaine Date: Mon, 21 Sep 2020 13:59:46 -0700 Subject: [PATCH] Ensure that shapes are set for captured inputs to functions. Small eager tensors are captured as const ops, which don't have a shape attr. PiperOrigin-RevId: 332927858 Change-Id: Ic1ee86e2e5c7db6332d1b989580bdb599c382604 --- .../core/framework/graph_to_functiondef.cc | 9 +++++++++ .../optimizers/data/hoist_random_uniform.cc | 16 +++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc index e825aa722b5..786833f2d6a 100644 --- a/tensorflow/core/framework/graph_to_functiondef.cc +++ b/tensorflow/core/framework/graph_to_functiondef.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_node_util.h" @@ -442,6 +443,14 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, AttrValue value; *(value.mutable_list()->add_shape()) = attr.second.shape(); arg_attrs.mutable_attr()->insert({"_output_shapes", value}); + } else if (attr.first == "value" && node->type_string() == "Const") { + // Small eager tensors are captured as const ops rather than + // Placeholders. Add a _output_shapes arg_attr with the shape of the + // const tensor. + AttrValue value; + *(value.mutable_list()->add_shape()) = + attr.second.tensor().tensor_shape(); + arg_attrs.mutable_attr()->insert({"_output_shapes", value}); } if (attr.first == "_resource_arg_unique_id") { resource_arg_unique_id = attr.second.i(); diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc index e29b6201402..ebe8ff7522c 100644 --- a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc @@ -130,7 +130,8 @@ NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node, // We need to insert our argument before the placeholders, which are the last // arguments. -OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) { +OpDef_ArgDef* InsertSeedArgument(FunctionDef* function, int num_placeholders) { + OpDef* signature = function->mutable_signature(); int new_argument_idx = signature->input_arg_size() - num_placeholders; signature->add_input_arg(); for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) { @@ -140,6 +141,16 @@ OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) { seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx)); seed_arg->set_type(DT_INT64); + // Update arg_attr, any arg_attrs for the placeholders how have index one + // higher. + for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) { + if (function->arg_attr().contains(i - 1)) { + (*function->mutable_arg_attr())[i] = + (*function->mutable_arg_attr())[i - 1]; + function->mutable_arg_attr()->erase(i - 1); + } + } + return seed_arg; } @@ -157,8 +168,7 @@ const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function, graph_utils::SetUniqueGraphFunctionName("stateless_function", library, stateless_function); - auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(), - num_placeholders); + auto* seed_arg = InsertSeedArgument(stateless_function, num_placeholders); auto* const random_uniform = stateless_function->mutable_node_def( function_utils::FindFunctionNodeWithOp("RandomUniform",