diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 6907988d08f..1edb4ef4655 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -979,6 +979,41 @@ class SymbolicShapeRefiner { return true; } + // Return true if the annotated shape is compatible with shape inference + // result. Examples: + // Inferred shape: ?, annotated shape: [10, 10] -> true; + // Inferred shape: [-1, 10], annotated shape: [10, 10] -> true; + // Inferred shape: [-1, 100], annotated shape: [10, 10] -> false; + // Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false. + bool CompatibleShapes(ShapeHandle inferred_shape, + ShapeHandle annotated_shape) const { + if (inferred_shape.SameHandle(annotated_shape)) { + return true; + } + if (!InferenceContext::RankKnown(inferred_shape)) { + return true; + } + if (InferenceContext::Rank(inferred_shape) != + InferenceContext::Rank(annotated_shape)) { + return false; + } + const int rank = InferenceContext::Rank(inferred_shape); + for (int i = 0; i < rank; ++i) { + if (!InferenceContext::DimKnownRank(inferred_shape, i) + .SameHandle( + InferenceContext::DimKnownRank(annotated_shape, i))) { + int64 val1 = InferenceContext::Value( + InferenceContext::DimKnownRank(inferred_shape, i)); + int64 val2 = InferenceContext::Value( + InferenceContext::DimKnownRank(annotated_shape, i)); + if (val1 >= 0 && val1 != val2) { + return false; + } + } + } + return true; + } + bool EquivalentShapesAndTypes(const std::vector& st1, const std::vector& st2) const { if (st1.size() != st2.size()) { @@ -1139,9 +1174,9 @@ class SymbolicShapeRefiner { return true; } - // Returns true if we want to update output values with running EvaluateNode() - // for this op, based on op type, data type, and size. 
- bool ShouldUpdateOutputValues(NodeContext* c, int64 max_size) { + // Returns true if we want to update output shapes and values by running + // EvaluateNode() for this op, based on op type, data type, and size. + bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) { InferenceContext* ic = c->inference_context.get(); // Due to the cost of running EvaluateNode(), we limit only to white listed @@ -1232,8 +1267,9 @@ class SymbolicShapeRefiner { } } - // Run a node to infer output values, and add it to the NodeContext. - Status UpdateOutputValues(const NodeDef& node, NodeContext* c) { + // Run a node to infer output shapes and values, and add it to the + // NodeContext. + Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) { InferenceContext* ic = c->inference_context.get(); // Input to EvaluateNode() @@ -1264,7 +1300,7 @@ class SymbolicShapeRefiner { ic->MakeShapeFromTensorShape(t->shape(), &output_shape)); if (ic->FullyDefined(ic->output(k)) && !EquivalentShapes(ic->output(k), output_shape)) { - LOG(WARNING) << "UpdateOutputValues() -- node: " << node.name() + LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name() << ", inferred output shape " << "doesn't match for k=" << k << ": " << "ic->output(k): " << ic->DebugString(ic->output(k)) @@ -1284,6 +1320,54 @@ class SymbolicShapeRefiner { return Status::OK(); } + // Update output shapes with annotated information. + // Currently only handles nodes with static shapes, i.e. shapes do not change + // during execution. + // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well. 
+ Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node, + NodeContext* c) const { + const auto& attr = node.attr(); + if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() || + attr.count(kOutputShapes) == 0) + return Status::OK(); + + InferenceContext* ic = c->inference_context.get(); + int output_size = attr.at(kOutputShapes).list().shape_size(); + + for (int i = 0; i < ic->num_outputs(); i++) { + // Annotated Switch node has only one output. Propagate the shape to all + // the outputs. + int shape_index = IsSwitch(node) ? 0 : i; + if (shape_index >= output_size) { + LOG(WARNING) + << "UpdateOutputShapesUsingAnnotatedInformation() -- node: " + << node.name() << ", inferred output shape size " + << ic->num_outputs() << ", annotated output shape size " + << output_size; + break; + } + + const TensorShapeProto& shape = + attr.at(kOutputShapes).list().shape(shape_index); + ShapeHandle output_shape; + TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape)); + + // Only use annotated shapes if the inferred shape is not fully defined + // and is compatible with the annotated shapes. + if (!ic->FullyDefined(ic->output(i)) && + CompatibleShapes(ic->output(i), output_shape)) { + VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: " + << node.name() << ", inferred output shape " << i << ": " + << "ic->output(i): " << ic->DebugString(ic->output(i)) + << ", annotated output shape: " << ic->DebugString(output_shape) + << " -- " << node.ShortDebugString(); + ic->set_output(i, output_shape); + } + } + + return Status::OK(); + } + Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed, NodeContext* c) { // Propagate tensors and shape tensors unless the node is fed. @@ -1476,16 +1560,19 @@ class SymbolicShapeRefiner { } if (aggressive_shape_inference_) { + // Update output shapes with annotated information. This is optional. 
+ UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError(); + // Update output tensor values using EvaluateNode() if we can. // Due to the cost of EvaluateNode(), we run it only for certain op types // (white listed) and small integer tensors. const int max_element_size = 17; // Max up to 4x4 matrix or similar. if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) || - !ShouldUpdateOutputValues(c, max_element_size)) { + !ShouldUpdateOutputShapesAndValues(c, max_element_size)) { return Status::OK(); } - UpdateOutputValues(node, c).IgnoreError(); // This is optional. + UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional. } return Status::OK(); } @@ -1797,6 +1884,7 @@ Status GraphProperties::UpdateShapes( // UpdateNode calls UpdateFunction if a function node is detected. TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes)); } + return Status::OK(); } diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index 3fcad6eb1b1..bb7e6ed16a6 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -27,6 +27,45 @@ namespace tensorflow { namespace grappler { +// Optional attributes that tell about node output information. +// We use this side information, if provided, for static shape inference +// and VirtualScheduler scheduling. + +// Switch op attribute as a vector of int that tells which branch the +// Switch output is taken on every round of execution. +// Used for scheduling ops after Switch correctly (e.g., While loop). +ABSL_CONST_INIT const char kOutputSlots[] = "_output_slot_vector"; + +// Example: +// Assume a node has two outputs and is executed three times. 
Then it has: +// _execution_count = 3 +// _output_sizes_vector = [2, 2, 2] +// _output_dtype_vector.size = 6 +// _output_shape_vector.size = 6 + +// If all the iterations have the same output shapes, then +// _execution_count = 3 +// _same_output_for_iterations = true +// _output_sizes_vector = [2] +// _output_dtype_vector.size = 2 +// _output_shape_vector.size = 2 + +// How many times this node has been executed. +ABSL_CONST_INIT const char kExecutionCount[] = "_execution_count"; + +// Records the output sizes for each round of execution. +ABSL_CONST_INIT const char kOutputSizes[] = "_output_sizes_vector"; + +// The node has been scheduled multiple times with outputs that have the same +// shape. +ABSL_CONST_INIT const char kOutputSame[] = "_same_output_for_iterations"; + +// Outputs DataType vector. +ABSL_CONST_INIT const char kOutputTypes[] = "_output_dtype_vector"; + +// Outputs TensorShapeProto vector. +ABSL_CONST_INIT const char kOutputShapes[] = "_output_shape_vector"; + class SymbolicShapeRefiner; class TopoQueue; diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index ce8d367a34f..6c37d2418ae 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -1793,6 +1793,103 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) { ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value()); } +TEST_F(GraphPropertiesTest, ShapeAnnotation) { + GrapplerItem item; + TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder") + .Attr("dtype", DT_FLOAT) + .Attr("shape", PartialTensorShape({-1, -1})) + .Finalize(item.graph.add_node())); + // Annotate shapes. 
+ TF_CHECK_OK(NodeDefBuilder("Identity", "Identity") + .Attr("dtype", DT_FLOAT) + .Attr("_same_output_for_iterations", true) + .Attr("_output_shape_vector", {TensorShape({5, 7})}) + .Input("Input", 0, DT_FLOAT) + .Finalize(item.graph.add_node())); + { + GraphProperties properties(item); + // Without aggressive_shape_inference, ignore annotated information. + TF_CHECK_OK(properties.InferStatically( + /*assume_valid_feeds=*/false, + /*aggressive_shape_inference=*/false)); + const auto props = properties.GetOutputProperties("Identity"); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ(2, prop.shape().dim_size()); + // Get unknown shapes without using annotated information. + EXPECT_EQ("float: [-1,-1]", PropToString(prop)); + } + { + GraphProperties properties(item); + // Use annotated information. + TF_CHECK_OK(properties.InferStatically( + /*assume_valid_feeds=*/false, + /*aggressive_shape_inference=*/true)); + const auto props = properties.GetOutputProperties("Identity"); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ(2, prop.shape().dim_size()); + // Update output shape using annotated shapes. + EXPECT_EQ("float: [5,7]", PropToString(prop)); + } +} + +TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) { + GrapplerItem item; + TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder") + .Attr("dtype", DT_FLOAT) + .Attr("shape", PartialTensorShape({-1, 100})) + .Finalize(item.graph.add_node())); + // Annotate shapes. + TF_CHECK_OK(NodeDefBuilder("Identity", "Identity") + .Attr("dtype", DT_FLOAT) + .Attr("_same_output_for_iterations", true) + .Attr("_output_shape_vector", {TensorShape({10, 100})}) + .Input("Input", 0, DT_FLOAT) + .Finalize(item.graph.add_node())); + GraphProperties properties(item); + // Use annotated information. 
+ TF_CHECK_OK(properties.InferStatically( + /*assume_valid_feeds=*/false, + /*aggressive_shape_inference=*/true)); + const auto props = properties.GetOutputProperties("Identity"); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ(2, prop.shape().dim_size()); + // Compatible shapes. Update output shape using annotated shapes. + EXPECT_EQ("float: [10,100]", PropToString(prop)); +} + +TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) { + GrapplerItem item; + TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder") + .Attr("dtype", DT_FLOAT) + .Attr("shape", PartialTensorShape({-1, 100})) + .Finalize(item.graph.add_node())); + // Annotate shapes. + TF_CHECK_OK(NodeDefBuilder("Identity", "Identity") + .Attr("dtype", DT_FLOAT) + .Attr("_same_output_for_iterations", true) + .Attr("_output_shape_vector", {TensorShape({10, 10})}) + .Input("Input", 0, DT_FLOAT) + .Finalize(item.graph.add_node())); + GraphProperties properties(item); + // Use annotated information. + TF_CHECK_OK(properties.InferStatically( + /*assume_valid_feeds=*/false, + /*aggressive_shape_inference=*/true)); + const auto props = properties.GetOutputProperties("Identity"); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ(2, prop.shape().dim_size()); + // Incompatible shapes. Do not use annotated shapes. 
+ EXPECT_EQ("float: [-1,100]", PropToString(prop)); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 0aac0348b51..606f03727a8 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -36,12 +36,6 @@ namespace tensorflow { namespace grappler { namespace { -// Optional attribute name for Switch op as a vector of int that tells -// which branch the Switch output is taken on every round of execution. -// We use this side information, if provided, for scheduling ops after Switch -// correctly (e.g., While loop). -constexpr char kOutputSlots[] = "_output_slot_vector"; - Costs CombineCosts(const Costs& left, const Costs& right) { CHECK_NE(left.max_memory, kMemoryUnknown); CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);