Fix handling of PartitionedCall in ShapeInference pass. Previously, the pass
assumed that a function call could only happen from a node whose op is the
function's name, and could not perform inter-procedural shape propagation
through functions called using PartitionedCall.
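
For context: a node that calls a function directly names the callee in its op,
whereas PartitionedCall and StatefulPartitionedCall carry the callee in their
"f" attr, so looking up `function_node->op()` in the function library fails for
them. Below is a minimal, hypothetical sketch of that resolution step (the
helper name is made up for illustration; the actual change relies on the
existing `NameAndAttrsFromFunctionCall` helper, as the diff shows):

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"

using tensorflow::AttrSlice;
using tensorflow::NameAttrList;
using tensorflow::NodeDef;
using tensorflow::Status;

// Hypothetical helper, for illustration only: resolve which function a node
// calls, covering both direct calls and (Stateful)PartitionedCall.
Status ResolveCalleeSketch(const NodeDef& call, NameAttrList* function) {
  if (call.op() == "PartitionedCall" ||
      call.op() == "StatefulPartitionedCall") {
    // The callee is stored in the "f" function attr, not in the op name.
    const tensorflow::AttrValue* attr = AttrSlice(call).Find("f");
    if (attr == nullptr || !attr->has_func()) {
      return tensorflow::errors::InvalidArgument(
          "Expected a 'f' attr naming the callee on node ", call.name());
    }
    *function = attr->func();
  } else {
    // Direct call: the op name itself is the function name.
    function->set_name(call.op());
  }
  return Status::OK();
}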

Note that in order to use this from a Grappler pass P, P needs to override
`UsesFunctionLibrary` and return `true` from it (see the sketch below).
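
For illustration, a minimal sketch of such an override (the optimizer name
here is hypothetical; the interface is the GraphOptimizer base class, and the
real change to ScopedAllocatorOptimizer at the end of this diff does exactly
this flip):

#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"

namespace tensorflow {
namespace grappler {

// Hypothetical pass: returning true from UsesFunctionLibrary() is what makes
// the function library, and hence inter-procedural shape information for
// functions called through PartitionedCall, available to the pass.
class MyFunctionAwareOptimizer : public GraphOptimizer {
 public:
  string name() const override { return "my_function_aware_optimizer"; }

  bool UsesFunctionLibrary() const override { return true; }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override {
    *optimized_graph = item.graph;  // No-op rewrite; a real pass works here.
    return Status::OK();
  }

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimized_graph, double result) override {}
};

}  // namespace grappler
}  // namespace tensorflow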

PiperOrigin-RevId: 308886562
Change-Id: I1d64c07d9c5fbab2365725e2c213b95f2b21ae01
Author:    George Karpenkov (2020-04-28 13:29:09 -07:00)
Committer: TensorFlower Gardener
Commit:    ec9e078470 (parent 60073244fc)

4 changed files with 113 additions and 25 deletions

--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD

@@ -91,6 +91,7 @@ tf_cc_test(
     deps = [
         ":graph_properties",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:functional_ops",
         "//tensorflow/cc:scope",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",

--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 
 #include "absl/types/optional.h"
+#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -590,6 +591,20 @@ bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
   return kOpTpeWhitelist->find(op_type) != kOpTpeWhitelist->end();
 }
 
+// Negative shape size of '-1' represents unknown, while negative shape sizes
+// less than -1 represent unknown symbolic shapes (e.g. the shape of [-5, 5, -1,
+// -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes,
+// we need to normalize them: mark all values <-1 as "unknown" (-1).
+static void NormalizeShapeForOutput(TensorShapeProto* shape) {
+  for (int i = 0; i < shape->dim_size(); i++) {
+    if (shape->dim(i).size() < -1) {
+      VLOG(2) << "Normalizing dimension: " << i << " from "
+              << shape->dim(i).size() << " to -1";
+      shape->mutable_dim(i)->set_size(-1);
+    }
+  }
+}
+
 // Processes symbolic shapes.
 // Each symbolic shape or dimension is represented by a handle. Unlike the TF
 // shape refiner which creates new handles every time it processes an unknown
@@ -722,7 +737,8 @@
     return it->second.inference_context.get();
   }
 
-  // Forward the shapes from the function input nodes to
+  // Forward the shapes from the function input nodes, PartitionedCalls or
+  // StatefulPartitionedCall to
   // the argument nodes (which are Placeholder nodes), then
   // perform shape inference on the function body.
   //
@@ -732,10 +748,12 @@
   // In the event of an error, UpdateNode will simply set `function_node`'s
   // output shape to be Unknown.
   Status UpdateFunction(const NodeDef* function_node) {
-    auto it = fun_to_grappler_function_item_.find(function_node->op());
+    NameAttrList function;
+    TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function));
+    auto it = fun_to_grappler_function_item_.find(function.name());
     if (it == fun_to_grappler_function_item_.end()) {
       return errors::InvalidArgument(
-          function_node->op(),
+          function.name(),
           " was not previously added to SymbolicShapeRefiner.");
     }
 
@@ -743,7 +761,7 @@
         it->second;
     if (!maybe_grappler_function_item.has_value()) {
       VLOG(3) << "Skip failed to instantiate function call: function_name="
-              << function_node->op();
+              << function.name();
 
       auto* ctx = GetNodeContext(function_node);
       auto* ic = ctx->inference_context.get();
@@ -789,11 +807,7 @@
       const auto& handle = input_ic->output(output_port_num);
       input_ic->ShapeHandleToProto(handle, &proto);
       // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
-      for (int i = 0; i < proto.dim_size(); i++) {
-        if (proto.dim(i).size() < -1) {
-          proto.mutable_dim(i)->set_size(-1);
-        }
-      }
+      NormalizeShapeForOutput(&proto);
 
       AttrValue output_attr;
       output_attr.mutable_list()->add_shape()->Swap(&proto);
@@ -870,8 +884,9 @@
             out_tensor.ToString(), " has invalid position ", out_tensor.index(),
             " (output_properties.size() = ", output_properties.size(), ").");
       }
-      auto const& outprop = output_properties[out_tensor.index()];
-      const TensorShapeProto& shape = outprop.shape();
+      auto& outprop = output_properties[out_tensor.index()];
+      TensorShapeProto shape = outprop.shape();
+      NormalizeShapeForOutput(&shape);
       ShapeHandle out;
       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
       ic->set_output(output, out);
@@ -1196,15 +1211,14 @@
     return true;
   }
 
-  Status AddFunction(const NodeDef* function_node) {
-    auto it = fun_to_grappler_function_item_.find(function_node->op());
+  Status AddFunction(const NodeDef* function_node, NameAttrList function) {
+    auto it = fun_to_grappler_function_item_.find(function.name());
     if (it != fun_to_grappler_function_item_.end()) {
       return Status::OK();
     }
-
     const FunctionDef* function_def =
-        CHECK_NOTNULL(function_library_.Find(function_node->op()));
+        CHECK_NOTNULL(function_library_.Find(function.name()));
 
     GrapplerFunctionItem grappler_function_item;
     Status function_instantiated =
         MakeGrapplerFunctionItem(*function_def, function_library_,
@@ -1242,10 +1256,15 @@
   Status AddNode(const NodeDef* node) {
     NodeContext& node_ctx = node_to_context_[node];
-    TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data));
+    NameAttrList function;
+    TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function));
+
+    // For PartitionedCall, op_data represents the function info.
+    TF_RETURN_IF_ERROR(
+        function_library_.LookUp(function.name(), &node_ctx.op_data));
     if (node_ctx.op_data->is_function_op) {
-      TF_RETURN_IF_ERROR(AddFunction(node));
+      TF_RETURN_IF_ERROR(AddFunction(node, function));
     }
 
     TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
                                          &node_ctx.input_types,
                                          &node_ctx.output_types));
@@ -2525,13 +2544,7 @@ Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def,
       TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape();
       *proto = tensor_property.shape();
       if (!allow_symbolic_shapes) {
-        // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to
-        // -1.
-        for (int i = 0; i < proto->dim_size(); i++) {
-          if (proto->dim(i).size() < -1) {
-            proto->mutable_dim(i)->set_size(-1);
-          }
-        }
+        NormalizeShapeForOutput(proto);
       }
     }
     (*node->mutable_attr())["_output_shapes"] = attr_output_shape;

--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 
 #include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/functional_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/graph_def_util.h"
 #include "tensorflow/core/framework/node_def_builder.h"
@@ -2002,6 +2003,79 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
   EXPECT_EQ("float: [10,100]", PropToString(prop));
 }
 
+TEST_F(GraphPropertiesTest, PartitionedCallOp) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  FunctionDefLibrary library;
+  FunctionDef called_func = FunctionDefHelper::Create(
+      "identity_function",
+      /*in_def=*/{"arg0: int32"},
+      /*out_def=*/{"ret0: int32"},
+      /*attr_def=*/{},
+      {{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}},
+      /*ret_def=*/{{"ret0", "Identity:output:0"}});
+  *library.add_function() = called_func;
+  TF_CHECK_OK(root.graph()->AddFunctionLibrary(library));
+
+  Output in = ops::Const(root, {3, 1, 2, 0});
+  NameAttrList b_name_attr;
+  b_name_attr.set_name("identity_function");
+  ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32},
+                            b_name_attr);
+
+  GrapplerItem item;
+  TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+  GraphProperties properties(item);
+  TF_CHECK_OK(properties.InferStatically(
+      /*assume_valid_feeds=*/true,
+      /*aggressive_shape_inference=*/false,
+      /*include_tensor_values=*/true));
+
+  EXPECT_EQ("int32: [4]",
+            PropToString(properties.GetOutputProperties("identity_call")[0]));
+}
+
+TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) {
+  auto f = FunctionDefHelper::Create(
+      // Name
+      "FunctionWhichAdds",
+      // Inputs
+      {"arg0: int32", "arg1: int32"},
+      // Outputs
+      {"ret0: int32"},
+      /*attr_def=*/{},
+      // Nodes
+      {{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}},
+      /*ret_def=*/{{"ret0", "a:z:0"}});
+
+  FunctionDefLibrary function_lib;
+  function_lib.add_function()->Swap(&f);
+  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+  TF_CHECK_OK(root.graph()->AddFunctionLibrary(function_lib));
+
+  PartialTensorShape input_shape({2, 2, -1});
+
+  Output in1 =
+      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
+  Output in2 =
+      ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
+  NameAttrList b_name_attr;
+  b_name_attr.set_name("FunctionWhichAdds");
+  ops::PartitionedCall call(root.WithOpName("add_call"), {in1, in2}, {DT_INT32},
+                            b_name_attr);
+
+  GrapplerItem item;
+  TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+  GraphProperties properties(item);
+  TF_CHECK_OK(properties.InferStatically(
+      /*assume_valid_feeds=*/true,
+      /*aggressive_shape_inference=*/false,
+      /*include_tensor_values=*/true));
+
+  EXPECT_EQ("int32: [2,2,-1]",
+            PropToString(properties.GetOutputProperties("add_call")[0]));
+}
+
 TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) {
   // A function, which we cannot infer output shape statically.
   auto f = FunctionDefHelper::Create(

--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h

@@ -43,7 +43,7 @@ class ScopedAllocatorOptimizer : public GraphOptimizer {
   string name() const override { return "scoped_allocator_optimizer"; }
 
-  bool UsesFunctionLibrary() const override { return false; }
+  bool UsesFunctionLibrary() const override { return true; }
 
   Status Optimize(Cluster* cluster, const GrapplerItem& item,
                   GraphDef* optimized_graph) override;