Ensure that shapes are set for captured inputs to functions. Small eager tensors are captured as const ops, which don't have a shape attr.

PiperOrigin-RevId: 332927858
Change-Id: Ic1ee86e2e5c7db6332d1b989580bdb599c382604
This commit is contained in:
Bruce Fontaine 2020-09-21 13:59:46 -07:00 committed by TensorFlower Gardener
parent 65139ecc00
commit 906fac3188
2 changed files with 22 additions and 3 deletions

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
@ -442,6 +443,14 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
AttrValue value;
*(value.mutable_list()->add_shape()) = attr.second.shape();
arg_attrs.mutable_attr()->insert({"_output_shapes", value});
} else if (attr.first == "value" && node->type_string() == "Const") {
// Small eager tensors are captured as const ops rather than
// Placeholders. Add a _output_shapes arg_attr with the shape of the
// const tensor.
AttrValue value;
*(value.mutable_list()->add_shape()) =
attr.second.tensor().tensor_shape();
arg_attrs.mutable_attr()->insert({"_output_shapes", value});
}
if (attr.first == "_resource_arg_unique_id") {
resource_arg_unique_id = attr.second.i();

View File

@ -130,7 +130,8 @@ NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
// We need to insert our argument before the placeholders, which are the last
// arguments.
OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
OpDef_ArgDef* InsertSeedArgument(FunctionDef* function, int num_placeholders) {
OpDef* signature = function->mutable_signature();
int new_argument_idx = signature->input_arg_size() - num_placeholders;
signature->add_input_arg();
for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
@ -140,6 +141,16 @@ OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx));
seed_arg->set_type(DT_INT64);
// Update arg_attr, any arg_attrs for the placeholders how have index one
// higher.
for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
if (function->arg_attr().contains(i - 1)) {
(*function->mutable_arg_attr())[i] =
(*function->mutable_arg_attr())[i - 1];
function->mutable_arg_attr()->erase(i - 1);
}
}
return seed_arg;
}
@ -157,8 +168,7 @@ const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function,
graph_utils::SetUniqueGraphFunctionName("stateless_function", library,
stateless_function);
auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(),
num_placeholders);
auto* seed_arg = InsertSeedArgument(stateless_function, num_placeholders);
auto* const random_uniform = stateless_function->mutable_node_def(
function_utils::FindFunctionNodeWithOp("RandomUniform",