From ec9e07847006f33d97032b8f95783cbc9164b5f0 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Tue, 28 Apr 2020 13:29:09 -0700 Subject: [PATCH] Fix handling of PartitionedCall in ShapeInference pass. Previously, the pass assumed that the function call can only happen from the function input node, and could not perform inter-procedural shape propagation through functions called using PartitionedCall. Note that in order to use this from a grappler pass P, P needs to override `UsesFunctionLibrary` and return `true` from it. PiperOrigin-RevId: 308886562 Change-Id: I1d64c07d9c5fbab2365725e2c213b95f2b21ae01 --- tensorflow/core/grappler/costs/BUILD | 1 + .../core/grappler/costs/graph_properties.cc | 61 +++++++++------ .../grappler/costs/graph_properties_test.cc | 74 +++++++++++++++++++ .../optimizers/scoped_allocator_optimizer.h | 2 +- 4 files changed, 113 insertions(+), 25 deletions(-) diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 1c7493cad35..02a26cdd390 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -91,6 +91,7 @@ tf_cc_test( deps = [ ":graph_properties", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:scope", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index be987f2d151..db0965eec7b 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def_util.h" @@ -590,6 +591,20 @@ bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) { return kOpTpeWhitelist->find(op_type) != kOpTpeWhitelist->end(); } +// Negative shape size of '-1' represents unknown, while negative shape sizes +// less than -1 represent unknown symbolic shapes (e.g. the shape of [-5, 5, -1, +// -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes, +// we need to normalize them: mark all values <-1 as "unknown" (-1). +static void NormalizeShapeForOutput(TensorShapeProto* shape) { + for (int i = 0; i < shape->dim_size(); i++) { + if (shape->dim(i).size() < -1) { + VLOG(2) << "Normalizing dimension: " << i << " from " + << shape->dim(i).size() << " to -1"; + shape->mutable_dim(i)->set_size(-1); + } + } +} + // Processes symbolic shapes. // Each symbolic shape or dimension is represented by a handle. Unlike the TF // shape refiner which creates new handles every time it processes an unknown @@ -722,7 +737,8 @@ class SymbolicShapeRefiner { return it->second.inference_context.get(); } - // Forward the shapes from the function input nodes to + // Forward the shapes from the function input nodes, PartitionedCalls or + // StatefulPartitionedCall to // the argument nodes (which are Placeholder nodes), then // perform shape inference on the function body. // @@ -732,10 +748,12 @@ class SymbolicShapeRefiner { // In the event of an error, UpdateNode will simply set `function_node`'s // output shape to be Unknown. Status UpdateFunction(const NodeDef* function_node) { - auto it = fun_to_grappler_function_item_.find(function_node->op()); + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function)); + auto it = fun_to_grappler_function_item_.find(function.name()); if (it == fun_to_grappler_function_item_.end()) { return errors::InvalidArgument( - function_node->op(), + function.name(), " was not previously added to SymbolicShapeRefiner."); } @@ -743,7 +761,7 @@ class SymbolicShapeRefiner { it->second; if (!maybe_grappler_function_item.has_value()) { VLOG(3) << "Skip failed to instantiate function call: function_name=" - << function_node->op(); + << function.name(); auto* ctx = GetNodeContext(function_node); auto* ic = ctx->inference_context.get(); @@ -789,11 +807,7 @@ class SymbolicShapeRefiner { const auto& handle = input_ic->output(output_port_num); input_ic->ShapeHandleToProto(handle, &proto); // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1. - for (int i = 0; i < proto.dim_size(); i++) { - if (proto.dim(i).size() < -1) { - proto.mutable_dim(i)->set_size(-1); - } - } + NormalizeShapeForOutput(&proto); AttrValue output_attr; output_attr.mutable_list()->add_shape()->Swap(&proto); @@ -870,8 +884,9 @@ class SymbolicShapeRefiner { out_tensor.ToString(), " has invalid position ", out_tensor.index(), " (output_properties.size() = ", output_properties.size(), ")."); } - auto const& outprop = output_properties[out_tensor.index()]; - const TensorShapeProto& shape = outprop.shape(); + auto& outprop = output_properties[out_tensor.index()]; + TensorShapeProto shape = outprop.shape(); + NormalizeShapeForOutput(&shape); ShapeHandle out; TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out)); ic->set_output(output, out); @@ -1196,15 +1211,14 @@ class SymbolicShapeRefiner { return true; } - Status AddFunction(const NodeDef* function_node) { - auto it = fun_to_grappler_function_item_.find(function_node->op()); + Status AddFunction(const NodeDef* function_node, NameAttrList function) { + auto it = fun_to_grappler_function_item_.find(function.name()); if (it != fun_to_grappler_function_item_.end()) { return Status::OK(); } const FunctionDef* function_def = - CHECK_NOTNULL(function_library_.Find(function_node->op())); - + CHECK_NOTNULL(function_library_.Find(function.name())); GrapplerFunctionItem grappler_function_item; Status function_instantiated = MakeGrapplerFunctionItem(*function_def, function_library_, @@ -1242,10 +1256,15 @@ class SymbolicShapeRefiner { Status AddNode(const NodeDef* node) { NodeContext& node_ctx = node_to_context_[node]; - TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data)); + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function)); + + // For PartitionedCall, op_data represents the function info. + TF_RETURN_IF_ERROR( + function_library_.LookUp(function.name(), &node_ctx.op_data)); if (node_ctx.op_data->is_function_op) { - TF_RETURN_IF_ERROR(AddFunction(node)); + TF_RETURN_IF_ERROR(AddFunction(node, function)); } TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def, @@ -2525,13 +2544,7 @@ Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def, TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape(); *proto = tensor_property.shape(); if (!allow_symbolic_shapes) { - // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to - // -1. - for (int i = 0; i < proto->dim_size(); i++) { - if (proto->dim(i).size() < -1) { - proto->mutable_dim(i)->set_size(-1); - } - } + NormalizeShapeForOutput(proto); } } (*node->mutable_attr())["_output_shapes"] = attr_output_shape; diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 135fc521668..4d785e1b3a6 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -2002,6 +2003,79 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) { EXPECT_EQ("float: [10,100]", PropToString(prop)); } +TEST_F(GraphPropertiesTest, PartitionedCallOp) { + Scope root = Scope::NewRootScope().ExitOnError(); + FunctionDefLibrary library; + FunctionDef called_func = FunctionDefHelper::Create( + "identity_function", + /*in_def=*/{"arg0: int32"}, + /*out_def=*/{"ret0: int32"}, + /*attr_def=*/{}, + {{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}}, + /*ret_def=*/{{"ret0", "Identity:output:0"}}); + *library.add_function() = called_func; + TF_CHECK_OK(root.graph()->AddFunctionLibrary(library)); + + Output in = ops::Const(root, {3, 1, 2, 0}); + NameAttrList b_name_attr; + b_name_attr.set_name("identity_function"); + ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32}, + b_name_attr); + + GrapplerItem item; + TF_CHECK_OK(root.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically( + /*assume_valid_feeds=*/true, + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/true)); + + EXPECT_EQ("int32: [4]", + PropToString(properties.GetOutputProperties("identity_call")[0])); +} + +TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) { + auto f = FunctionDefHelper::Create( + // Name + "FunctionWhichAdds", + // Inputs + {"arg0: int32", "arg1: int32"}, + // Outputs + {"ret0: int32"}, + /*attr_def=*/{}, + // Nodes + {{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}}, + /*ret_def=*/{{"ret0", "a:z:0"}}); + + FunctionDefLibrary function_lib; + function_lib.add_function()->Swap(&f); + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(root.graph()->AddFunctionLibrary(function_lib)); + + PartialTensorShape input_shape({2, 2, -1}); + Output in1 = + ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape)); + Output in2 = + ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape)); + NameAttrList b_name_attr; + b_name_attr.set_name("FunctionWhichAdds"); + ops::PartitionedCall call(root.WithOpName("add_call"), {in1, in2}, {DT_INT32}, + b_name_attr); + + GrapplerItem item; + TF_CHECK_OK(root.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically( + /*assume_valid_feeds=*/true, + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/true)); + + EXPECT_EQ("int32: [2,2,-1]", + PropToString(properties.GetOutputProperties("add_call")[0])); +} + TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) { // A function, which we cannot infer output shape statically. auto f = FunctionDefHelper::Create( diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h index acc28f934dc..96bc90c8a95 100644 --- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h +++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h @@ -43,7 +43,7 @@ class ScopedAllocatorOptimizer : public GraphOptimizer { string name() const override { return "scoped_allocator_optimizer"; } - bool UsesFunctionLibrary() const override { return false; } + bool UsesFunctionLibrary() const override { return true; } Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) override;