Fix handling of PartitionedCall in ShapeInference pass. Previously, the pass
assumed that the function call can only happen from the function input node, and could not perform inter-procedural shape propagation through functions called using PartitionedCall. Note that in order to use this from a grappler pass P, P needs to override `UsesFunctionLibrary` and return `true` from it. PiperOrigin-RevId: 308886562 Change-Id: I1d64c07d9c5fbab2365725e2c213b95f2b21ae01
This commit is contained in:
parent
60073244fc
commit
ec9e078470
tensorflow/core/grappler
@ -91,6 +91,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":graph_properties",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
@ -590,6 +591,20 @@ bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
|
||||
return kOpTpeWhitelist->find(op_type) != kOpTpeWhitelist->end();
|
||||
}
|
||||
|
||||
// Negative shape size of '-1' represents unknown, while negative shape sizes
|
||||
// less than -1 represent unknown symbolic shapes (e.g. the shape of [-5, 5, -1,
|
||||
// -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes,
|
||||
// we need to normalize them: mark all values <-1 as "unknown" (-1).
|
||||
static void NormalizeShapeForOutput(TensorShapeProto* shape) {
|
||||
for (int i = 0; i < shape->dim_size(); i++) {
|
||||
if (shape->dim(i).size() < -1) {
|
||||
VLOG(2) << "Normalizing dimension: " << i << " from "
|
||||
<< shape->dim(i).size() << " to -1";
|
||||
shape->mutable_dim(i)->set_size(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Processes symbolic shapes.
|
||||
// Each symbolic shape or dimension is represented by a handle. Unlike the TF
|
||||
// shape refiner which creates new handles every time it processes an unknown
|
||||
@ -722,7 +737,8 @@ class SymbolicShapeRefiner {
|
||||
return it->second.inference_context.get();
|
||||
}
|
||||
|
||||
// Forward the shapes from the function input nodes to
|
||||
// Forward the shapes from the function input nodes, PartitionedCalls or
|
||||
// StatefulPartitionedCall to
|
||||
// the argument nodes (which are Placeholder nodes), then
|
||||
// perform shape inference on the function body.
|
||||
//
|
||||
@ -732,10 +748,12 @@ class SymbolicShapeRefiner {
|
||||
// In the event of an error, UpdateNode will simply set `function_node`'s
|
||||
// output shape to be Unknown.
|
||||
Status UpdateFunction(const NodeDef* function_node) {
|
||||
auto it = fun_to_grappler_function_item_.find(function_node->op());
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function));
|
||||
auto it = fun_to_grappler_function_item_.find(function.name());
|
||||
if (it == fun_to_grappler_function_item_.end()) {
|
||||
return errors::InvalidArgument(
|
||||
function_node->op(),
|
||||
function.name(),
|
||||
" was not previously added to SymbolicShapeRefiner.");
|
||||
}
|
||||
|
||||
@ -743,7 +761,7 @@ class SymbolicShapeRefiner {
|
||||
it->second;
|
||||
if (!maybe_grappler_function_item.has_value()) {
|
||||
VLOG(3) << "Skip failed to instantiate function call: function_name="
|
||||
<< function_node->op();
|
||||
<< function.name();
|
||||
|
||||
auto* ctx = GetNodeContext(function_node);
|
||||
auto* ic = ctx->inference_context.get();
|
||||
@ -789,11 +807,7 @@ class SymbolicShapeRefiner {
|
||||
const auto& handle = input_ic->output(output_port_num);
|
||||
input_ic->ShapeHandleToProto(handle, &proto);
|
||||
// There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
|
||||
for (int i = 0; i < proto.dim_size(); i++) {
|
||||
if (proto.dim(i).size() < -1) {
|
||||
proto.mutable_dim(i)->set_size(-1);
|
||||
}
|
||||
}
|
||||
NormalizeShapeForOutput(&proto);
|
||||
|
||||
AttrValue output_attr;
|
||||
output_attr.mutable_list()->add_shape()->Swap(&proto);
|
||||
@ -870,8 +884,9 @@ class SymbolicShapeRefiner {
|
||||
out_tensor.ToString(), " has invalid position ", out_tensor.index(),
|
||||
" (output_properties.size() = ", output_properties.size(), ").");
|
||||
}
|
||||
auto const& outprop = output_properties[out_tensor.index()];
|
||||
const TensorShapeProto& shape = outprop.shape();
|
||||
auto& outprop = output_properties[out_tensor.index()];
|
||||
TensorShapeProto shape = outprop.shape();
|
||||
NormalizeShapeForOutput(&shape);
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
|
||||
ic->set_output(output, out);
|
||||
@ -1196,15 +1211,14 @@ class SymbolicShapeRefiner {
|
||||
return true;
|
||||
}
|
||||
|
||||
Status AddFunction(const NodeDef* function_node) {
|
||||
auto it = fun_to_grappler_function_item_.find(function_node->op());
|
||||
Status AddFunction(const NodeDef* function_node, NameAttrList function) {
|
||||
auto it = fun_to_grappler_function_item_.find(function.name());
|
||||
if (it != fun_to_grappler_function_item_.end()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const FunctionDef* function_def =
|
||||
CHECK_NOTNULL(function_library_.Find(function_node->op()));
|
||||
|
||||
CHECK_NOTNULL(function_library_.Find(function.name()));
|
||||
GrapplerFunctionItem grappler_function_item;
|
||||
Status function_instantiated =
|
||||
MakeGrapplerFunctionItem(*function_def, function_library_,
|
||||
@ -1242,10 +1256,15 @@ class SymbolicShapeRefiner {
|
||||
|
||||
Status AddNode(const NodeDef* node) {
|
||||
NodeContext& node_ctx = node_to_context_[node];
|
||||
TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data));
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function));
|
||||
|
||||
// For PartitionedCall, op_data represents the function info.
|
||||
TF_RETURN_IF_ERROR(
|
||||
function_library_.LookUp(function.name(), &node_ctx.op_data));
|
||||
|
||||
if (node_ctx.op_data->is_function_op) {
|
||||
TF_RETURN_IF_ERROR(AddFunction(node));
|
||||
TF_RETURN_IF_ERROR(AddFunction(node, function));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
|
||||
@ -2525,13 +2544,7 @@ Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def,
|
||||
TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape();
|
||||
*proto = tensor_property.shape();
|
||||
if (!allow_symbolic_shapes) {
|
||||
// There may be dim.size < -1 in SymbolicShapeRefiner. Change those to
|
||||
// -1.
|
||||
for (int i = 0; i < proto->dim_size(); i++) {
|
||||
if (proto->dim(i).size() < -1) {
|
||||
proto->mutable_dim(i)->set_size(-1);
|
||||
}
|
||||
}
|
||||
NormalizeShapeForOutput(proto);
|
||||
}
|
||||
}
|
||||
(*node->mutable_attr())["_output_shapes"] = attr_output_shape;
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
@ -2002,6 +2003,79 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
|
||||
EXPECT_EQ("float: [10,100]", PropToString(prop));
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, PartitionedCallOp) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
FunctionDefLibrary library;
|
||||
FunctionDef called_func = FunctionDefHelper::Create(
|
||||
"identity_function",
|
||||
/*in_def=*/{"arg0: int32"},
|
||||
/*out_def=*/{"ret0: int32"},
|
||||
/*attr_def=*/{},
|
||||
{{{"Identity"}, "Identity", {"arg0"}, {{"T", DT_INT32}}}},
|
||||
/*ret_def=*/{{"ret0", "Identity:output:0"}});
|
||||
*library.add_function() = called_func;
|
||||
TF_CHECK_OK(root.graph()->AddFunctionLibrary(library));
|
||||
|
||||
Output in = ops::Const(root, {3, 1, 2, 0});
|
||||
NameAttrList b_name_attr;
|
||||
b_name_attr.set_name("identity_function");
|
||||
ops::PartitionedCall call(root.WithOpName("identity_call"), {in}, {DT_INT32},
|
||||
b_name_attr);
|
||||
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(root.ToGraphDef(&item.graph));
|
||||
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/true,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/true));
|
||||
|
||||
EXPECT_EQ("int32: [4]",
|
||||
PropToString(properties.GetOutputProperties("identity_call")[0]));
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, NonTrivialInputPartitionedCallOp) {
|
||||
auto f = FunctionDefHelper::Create(
|
||||
// Name
|
||||
"FunctionWhichAdds",
|
||||
// Inputs
|
||||
{"arg0: int32", "arg1: int32"},
|
||||
// Outputs
|
||||
{"ret0: int32"},
|
||||
/*attr_def=*/{},
|
||||
// Nodes
|
||||
{{{"a"}, "Add", {"arg0", "arg1"}, {{"T", DT_INT32}}}},
|
||||
/*ret_def=*/{{"ret0", "a:z:0"}});
|
||||
|
||||
FunctionDefLibrary function_lib;
|
||||
function_lib.add_function()->Swap(&f);
|
||||
tensorflow::Scope root = tensorflow::Scope::NewRootScope();
|
||||
TF_CHECK_OK(root.graph()->AddFunctionLibrary(function_lib));
|
||||
|
||||
PartialTensorShape input_shape({2, 2, -1});
|
||||
Output in1 =
|
||||
ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
|
||||
Output in2 =
|
||||
ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
|
||||
NameAttrList b_name_attr;
|
||||
b_name_attr.set_name("FunctionWhichAdds");
|
||||
ops::PartitionedCall call(root.WithOpName("add_call"), {in1, in2}, {DT_INT32},
|
||||
b_name_attr);
|
||||
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(root.ToGraphDef(&item.graph));
|
||||
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/true,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/true));
|
||||
|
||||
EXPECT_EQ("int32: [2,2,-1]",
|
||||
PropToString(properties.GetOutputProperties("add_call")[0]));
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, ShapeAnnotatedFunctionOp) {
|
||||
// A function, which we cannot infer output shape statically.
|
||||
auto f = FunctionDefHelper::Create(
|
||||
|
@ -43,7 +43,7 @@ class ScopedAllocatorOptimizer : public GraphOptimizer {
|
||||
|
||||
string name() const override { return "scoped_allocator_optimizer"; }
|
||||
|
||||
bool UsesFunctionLibrary() const override { return false; }
|
||||
bool UsesFunctionLibrary() const override { return true; }
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) override;
|
||||
|
Loading…
Reference in New Issue
Block a user