Use annotated output shapes in shape inference

PiperOrigin-RevId: 236431154
This commit is contained in:
Andiry Xu 2019-03-01 23:16:52 -08:00 committed by TensorFlower Gardener
parent 386165b089
commit f832595165
4 changed files with 232 additions and 14 deletions

View File

@ -979,6 +979,41 @@ class SymbolicShapeRefiner {
return true; return true;
} }
// Return true if the annotated shape is compatible with shape inference
// result. Examples:
// Inferred shape: ?, annotated shape: [10, 10] -> true;
// Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
// Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
// Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
// Returns true when an annotated shape agrees with the result of shape
// inference. An unknown inferred rank (or an unknown inferred dimension)
// accepts any annotation; a known inferred dimension must match exactly.
// Examples:
//   Inferred shape: ?, annotated shape: [10, 10] -> true;
//   Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
//   Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
//   Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
bool CompatibleShapes(ShapeHandle inferred_shape,
                      ShapeHandle annotated_shape) const {
  // Identical handles are trivially compatible.
  if (inferred_shape.SameHandle(annotated_shape)) return true;
  // Nothing was inferred about the rank, so any annotation is acceptable.
  if (!InferenceContext::RankKnown(inferred_shape)) return true;
  const int rank = InferenceContext::Rank(inferred_shape);
  // A known rank must match the annotation's rank.
  if (rank != InferenceContext::Rank(annotated_shape)) return false;
  for (int d = 0; d < rank; ++d) {
    auto inferred_dim = InferenceContext::DimKnownRank(inferred_shape, d);
    auto annotated_dim = InferenceContext::DimKnownRank(annotated_shape, d);
    if (inferred_dim.SameHandle(annotated_dim)) continue;
    const int64 inferred_val = InferenceContext::Value(inferred_dim);
    const int64 annotated_val = InferenceContext::Value(annotated_dim);
    // A known (non-negative) inferred dimension must equal the annotated
    // value; an unknown inferred dimension accepts any annotation.
    if (inferred_val >= 0 && inferred_val != annotated_val) return false;
  }
  return true;
}
bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1, bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
const std::vector<ShapeAndType>& st2) const { const std::vector<ShapeAndType>& st2) const {
if (st1.size() != st2.size()) { if (st1.size() != st2.size()) {
@ -1139,9 +1174,9 @@ class SymbolicShapeRefiner {
return true; return true;
} }
// Returns true if we want to update output values with running EvaluateNode() // Returns true if we want to update output shapes and values with running
// for this op, based on op type, data type, and size. // EvaluateNode() for this op, based on op type, data type, and size.
bool ShouldUpdateOutputValues(NodeContext* c, int64 max_size) { bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) {
InferenceContext* ic = c->inference_context.get(); InferenceContext* ic = c->inference_context.get();
// Due to the cost of running EvaluateNode(), we limit only to white listed // Due to the cost of running EvaluateNode(), we limit only to white listed
@ -1232,8 +1267,9 @@ class SymbolicShapeRefiner {
} }
} }
// Run a node to infer output values, and add it to the NodeContext. // Run a node to infer output shapes and values, and add it to the
Status UpdateOutputValues(const NodeDef& node, NodeContext* c) { // NodeContext.
Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
InferenceContext* ic = c->inference_context.get(); InferenceContext* ic = c->inference_context.get();
// Input to EvaluateNode() // Input to EvaluateNode()
@ -1264,7 +1300,7 @@ class SymbolicShapeRefiner {
ic->MakeShapeFromTensorShape(t->shape(), &output_shape)); ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
if (ic->FullyDefined(ic->output(k)) && if (ic->FullyDefined(ic->output(k)) &&
!EquivalentShapes(ic->output(k), output_shape)) { !EquivalentShapes(ic->output(k), output_shape)) {
LOG(WARNING) << "UpdateOutputValues() -- node: " << node.name() LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
<< ", inferred output shape " << ", inferred output shape "
<< "doesn't match for k=" << k << ": " << "doesn't match for k=" << k << ": "
<< "ic->output(k): " << ic->DebugString(ic->output(k)) << "ic->output(k): " << ic->DebugString(ic->output(k))
@ -1284,6 +1320,54 @@ class SymbolicShapeRefiner {
return Status::OK(); return Status::OK();
} }
// Replaces unknown inferred output shapes with shapes annotated on the
// node, when such annotations exist and are marked as stable across
// executions. Only nodes with static shapes (i.e., shapes that do not
// change during execution) are handled.
// TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
                                                   NodeContext* c) const {
  const auto& attr = node.attr();
  // Require both the "same output every iteration" marker and the
  // annotated shape vector; otherwise there is nothing to apply.
  const bool has_static_annotation =
      attr.count(kOutputSame) > 0 && attr.at(kOutputSame).b() &&
      attr.count(kOutputShapes) > 0;
  if (!has_static_annotation) return Status::OK();

  InferenceContext* ic = c->inference_context.get();
  const auto& annotated_list = attr.at(kOutputShapes).list();
  const int num_annotated = annotated_list.shape_size();
  for (int port = 0; port < ic->num_outputs(); port++) {
    // An annotated Switch node carries a single shape; propagate it to
    // every output port.
    const int annotation_idx = IsSwitch(node) ? 0 : port;
    if (annotation_idx >= num_annotated) {
      LOG(WARNING)
          << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
          << node.name() << ", inferred output shape size "
          << ic->num_outputs() << ", annotated output shape size "
          << num_annotated;
      break;
    }

    ShapeHandle annotated_shape;
    TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(
        annotated_list.shape(annotation_idx), &annotated_shape));
    // Skip ports whose inferred shape is already fully known, or where
    // the annotation contradicts the inference result.
    if (ic->FullyDefined(ic->output(port)) ||
        !CompatibleShapes(ic->output(port), annotated_shape)) {
      continue;
    }
    VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
            << node.name() << ", inferred output shape " << port << ": "
            << "ic->output(i): " << ic->DebugString(ic->output(port))
            << ", annotated output shape: " << ic->DebugString(annotated_shape)
            << " -- " << node.ShortDebugString();
    ic->set_output(port, annotated_shape);
  }
  return Status::OK();
}
Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed, Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
NodeContext* c) { NodeContext* c) {
// Propagate tensors and shape tensors unless the node is fed. // Propagate tensors and shape tensors unless the node is fed.
@ -1476,16 +1560,19 @@ class SymbolicShapeRefiner {
} }
if (aggressive_shape_inference_) { if (aggressive_shape_inference_) {
// Update output shapes with annotated information. This is optional.
UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
// Update output tensor values using EvaluateNode() if we can. // Update output tensor values using EvaluateNode() if we can.
// Due to the cost of EvaluateNode(), we run it only for certain op types // Due to the cost of EvaluateNode(), we run it only for certain op types
// (white listed) and small integer tensors. // (white listed) and small integer tensors.
const int max_element_size = 17; // Max up to 4x4 matrix or similar. const int max_element_size = 17; // Max up to 4x4 matrix or similar.
if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) || if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
!ShouldUpdateOutputValues(c, max_element_size)) { !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
return Status::OK(); return Status::OK();
} }
UpdateOutputValues(node, c).IgnoreError(); // This is optional. UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional.
} }
return Status::OK(); return Status::OK();
} }
@ -1797,6 +1884,7 @@ Status GraphProperties::UpdateShapes(
// UpdateNode calls UpdateFunction if a function node is detected. // UpdateNode calls UpdateFunction if a function node is detected.
TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes)); TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
} }
return Status::OK(); return Status::OK();
} }

View File

@ -27,6 +27,45 @@ namespace tensorflow {
namespace grappler { namespace grappler {
// Optional node attributes that carry side information about node outputs.
// When provided, this information is used for static shape inference and
// for VirtualScheduler scheduling.

// Switch op attribute: a vector of int recording which branch the Switch
// output was taken on each round of execution.
// Used to schedule ops after a Switch correctly (e.g., in a While loop).
ABSL_CONST_INIT const char kOutputSlots[] = "_output_slot_vector";

// Example:
// Assume a node has two outputs and was executed three times. Then it has:
//   _execution_count = 3
//   _output_sizes_vector = [2, 2, 2]
//   _output_dtype_vector.size = 6
//   _output_shape_vector.size = 6
// If all iterations produced the same output shapes, then:
//   _execution_count = 3
//   _same_output_for_iterations = true
//   _output_sizes_vector = [2]
//   _output_dtype_vector.size = 2
//   _output_shape_vector.size = 2

// How many times this node has been executed.
ABSL_CONST_INIT const char kExecutionCount[] = "_execution_count";
// Records the output sizes for each round of execution.
ABSL_CONST_INIT const char kOutputSizes[] = "_output_sizes_vector";
// Set when the node has been executed multiple times and all executions
// produced outputs with identical shapes.
ABSL_CONST_INIT const char kOutputSame[] = "_same_output_for_iterations";
// Output DataTypes, flattened across executions (see example above).
ABSL_CONST_INIT const char kOutputTypes[] = "_output_dtype_vector";
// Output TensorShapeProtos, flattened across executions (see example above).
ABSL_CONST_INIT const char kOutputShapes[] = "_output_shape_vector";
class SymbolicShapeRefiner; class SymbolicShapeRefiner;
class TopoQueue; class TopoQueue;

View File

@ -1793,6 +1793,103 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value()); ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
} }
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
  GrapplerItem item;
  // A placeholder whose shape is fully unknown ([-1, -1]).
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, -1}))
                  .Finalize(item.graph.add_node()));
  // Attach a static shape annotation ([5, 7]) to the Identity node.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({5, 7})})
                  .Attr("dtype", DT_FLOAT)
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  {
    // Conservative inference: annotated information must be ignored.
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false));
    const auto out_props = properties.GetOutputProperties("Identity");
    EXPECT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop = out_props[0];
    EXPECT_EQ(DT_FLOAT, out_prop.dtype());
    EXPECT_EQ(2, out_prop.shape().dim_size());
    // Shape stays unknown when the annotation is not consulted.
    EXPECT_EQ("float: [-1,-1]", PropToString(out_prop));
  }
  {
    // Aggressive inference: the annotated shape is applied.
    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true));
    const auto out_props = properties.GetOutputProperties("Identity");
    EXPECT_EQ(1, out_props.size());
    const OpInfo::TensorProperties& out_prop = out_props[0];
    EXPECT_EQ(DT_FLOAT, out_prop.dtype());
    EXPECT_EQ(2, out_prop.shape().dim_size());
    // The unknown inferred shape is replaced by the annotated one.
    EXPECT_EQ("float: [5,7]", PropToString(out_prop));
  }
}
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
  GrapplerItem item;
  // Placeholder with a partially known shape: [-1, 100].
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, 100}))
                  .Finalize(item.graph.add_node()));
  // Annotation [10, 100] agrees with the inferred [-1, 100].
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({10, 100})})
                  .Attr("dtype", DT_FLOAT)
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));
  const auto out_props = properties.GetOutputProperties("Identity");
  EXPECT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_EQ(2, out_prop.shape().dim_size());
  // Compatible annotation replaces the partially known inferred shape.
  EXPECT_EQ("float: [10,100]", PropToString(out_prop));
}
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
  GrapplerItem item;
  // Placeholder with a partially known shape: [-1, 100].
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, 100}))
                  .Finalize(item.graph.add_node()));
  // Annotation [10, 10] contradicts the inferred [-1, 100].
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({10, 10})})
                  .Attr("dtype", DT_FLOAT)
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));
  const auto out_props = properties.GetOutputProperties("Identity");
  EXPECT_EQ(1, out_props.size());
  const OpInfo::TensorProperties& out_prop = out_props[0];
  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
  EXPECT_EQ(2, out_prop.shape().dim_size());
  // Incompatible annotation is rejected; inferred shape is kept.
  EXPECT_EQ("float: [-1,100]", PropToString(out_prop));
}
} // namespace } // namespace
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow

View File

@ -36,12 +36,6 @@ namespace tensorflow {
namespace grappler { namespace grappler {
namespace { namespace {
// Optional attribute name for Switch op as a vector of int that tells
// which branch the Switch output is taken on every round of execution.
// We use this side information, if provided, for scheduling ops after Switch
// correctly (e.g., While loop).
constexpr char kOutputSlots[] = "_output_slot_vector";
Costs CombineCosts(const Costs& left, const Costs& right) { Costs CombineCosts(const Costs& left, const Costs& right) {
CHECK_NE(left.max_memory, kMemoryUnknown); CHECK_NE(left.max_memory, kMemoryUnknown);
CHECK_NE(left.max_per_op_buffers, kMemoryUnknown); CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);