Use annotated output shapes in shape inference
PiperOrigin-RevId: 236431154
This commit is contained in:
parent
386165b089
commit
f832595165
@ -979,6 +979,41 @@ class SymbolicShapeRefiner {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Return true if the annotated shape is compatible with shape inference
|
||||
// result. Examples:
|
||||
// Inferred shape: ?, annotated shape: [10, 10] -> true;
|
||||
// Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
|
||||
// Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
|
||||
// Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
|
||||
bool CompatibleShapes(ShapeHandle inferred_shape,
|
||||
ShapeHandle annotated_shape) const {
|
||||
if (inferred_shape.SameHandle(annotated_shape)) {
|
||||
return true;
|
||||
}
|
||||
if (!InferenceContext::RankKnown(inferred_shape)) {
|
||||
return true;
|
||||
}
|
||||
if (InferenceContext::Rank(inferred_shape) !=
|
||||
InferenceContext::Rank(annotated_shape)) {
|
||||
return false;
|
||||
}
|
||||
const int rank = InferenceContext::Rank(inferred_shape);
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
if (!InferenceContext::DimKnownRank(inferred_shape, i)
|
||||
.SameHandle(
|
||||
InferenceContext::DimKnownRank(annotated_shape, i))) {
|
||||
int64 val1 = InferenceContext::Value(
|
||||
InferenceContext::DimKnownRank(inferred_shape, i));
|
||||
int64 val2 = InferenceContext::Value(
|
||||
InferenceContext::DimKnownRank(annotated_shape, i));
|
||||
if (val1 >= 0 && val1 != val2) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
|
||||
const std::vector<ShapeAndType>& st2) const {
|
||||
if (st1.size() != st2.size()) {
|
||||
@ -1139,9 +1174,9 @@ class SymbolicShapeRefiner {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if we want to update output values with running EvaluateNode()
|
||||
// for this op, based on op type, data type, and size.
|
||||
bool ShouldUpdateOutputValues(NodeContext* c, int64 max_size) {
|
||||
// Returns true if we want to update output shapes and values with running
|
||||
// EvaluateNode() for this op, based on op type, data type, and size.
|
||||
bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) {
|
||||
InferenceContext* ic = c->inference_context.get();
|
||||
|
||||
// Due to the cost of running EvaluateNode(), we limit only to white listed
|
||||
@ -1232,8 +1267,9 @@ class SymbolicShapeRefiner {
|
||||
}
|
||||
}
|
||||
|
||||
// Run a node to infer output values, and add it to the NodeContext.
|
||||
Status UpdateOutputValues(const NodeDef& node, NodeContext* c) {
|
||||
// Run a node to infer output shapes and values, and add it to the
|
||||
// NodeContext.
|
||||
Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
|
||||
InferenceContext* ic = c->inference_context.get();
|
||||
|
||||
// Input to EvaluateNode()
|
||||
@ -1264,7 +1300,7 @@ class SymbolicShapeRefiner {
|
||||
ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
|
||||
if (ic->FullyDefined(ic->output(k)) &&
|
||||
!EquivalentShapes(ic->output(k), output_shape)) {
|
||||
LOG(WARNING) << "UpdateOutputValues() -- node: " << node.name()
|
||||
LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
|
||||
<< ", inferred output shape "
|
||||
<< "doesn't match for k=" << k << ": "
|
||||
<< "ic->output(k): " << ic->DebugString(ic->output(k))
|
||||
@ -1284,6 +1320,54 @@ class SymbolicShapeRefiner {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Update output shapes with annotated information.
|
||||
// Currently only handle nodes with static shapes, i.e. shapes do not change
|
||||
// during execution.
|
||||
// TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
|
||||
Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
|
||||
NodeContext* c) const {
|
||||
const auto& attr = node.attr();
|
||||
if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
|
||||
attr.count(kOutputShapes) == 0)
|
||||
return Status::OK();
|
||||
|
||||
InferenceContext* ic = c->inference_context.get();
|
||||
int output_size = attr.at(kOutputShapes).list().shape_size();
|
||||
|
||||
for (int i = 0; i < ic->num_outputs(); i++) {
|
||||
// Annotated Switch node has only one output. Propagate the shape to all
|
||||
// the outputs.
|
||||
int shape_index = IsSwitch(node) ? 0 : i;
|
||||
if (shape_index >= output_size) {
|
||||
LOG(WARNING)
|
||||
<< "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
|
||||
<< node.name() << ", inferred output shape size "
|
||||
<< ic->num_outputs() << ", annotated output shape size "
|
||||
<< output_size;
|
||||
break;
|
||||
}
|
||||
|
||||
const TensorShapeProto& shape =
|
||||
attr.at(kOutputShapes).list().shape(shape_index);
|
||||
ShapeHandle output_shape;
|
||||
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
|
||||
|
||||
// Only use annotated shapes if the inference shape is unknown and
|
||||
// compatible with annotated shapes.
|
||||
if (!ic->FullyDefined(ic->output(i)) &&
|
||||
CompatibleShapes(ic->output(i), output_shape)) {
|
||||
VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
|
||||
<< node.name() << ", inferred output shape " << i << ": "
|
||||
<< "ic->output(i): " << ic->DebugString(ic->output(i))
|
||||
<< ", annotated output shape: " << ic->DebugString(output_shape)
|
||||
<< " -- " << node.ShortDebugString();
|
||||
ic->set_output(i, output_shape);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
|
||||
NodeContext* c) {
|
||||
// Propagate tensors and shape tensors unless the node is fed.
|
||||
@ -1476,16 +1560,19 @@ class SymbolicShapeRefiner {
|
||||
}
|
||||
|
||||
if (aggressive_shape_inference_) {
|
||||
// Update output shapes with annotated information. This is optional.
|
||||
UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
|
||||
|
||||
// Update output tensor values using EvaluateNode() if we can.
|
||||
// Due to the cost of EvaluateNode(), we run it only for certain op types
|
||||
// (white listed) and small integer tensors.
|
||||
|
||||
const int max_element_size = 17; // Max up to 4x4 matrix or similar.
|
||||
if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
|
||||
!ShouldUpdateOutputValues(c, max_element_size)) {
|
||||
!ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
|
||||
return Status::OK();
|
||||
}
|
||||
UpdateOutputValues(node, c).IgnoreError(); // This is optional.
|
||||
UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional.
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1797,6 +1884,7 @@ Status GraphProperties::UpdateShapes(
|
||||
// UpdateNode calls UpdateFunction if a function node is detected.
|
||||
TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,45 @@ namespace tensorflow {
|
||||
|
||||
namespace grappler {
|
||||
|
||||
// Optional attributes that tell about node output information.
|
||||
// We use these side information, if provided, for static shape inference
|
||||
// and VirtualScheduler scheduling.
|
||||
|
||||
// Switch op attribute as a vector of int that tells which branch the
|
||||
// Switch output is taken on every round of execution.
|
||||
// Used for scheduling ops after Switch correctly (e.g., While loop).
|
||||
ABSL_CONST_INIT const char kOutputSlots[] = "_output_slot_vector";
|
||||
|
||||
// Example:
|
||||
// Assume a node has two outputs and iterated for three times. Then it has:
|
||||
// _execution_count = 3
|
||||
// _output_sizes_vector = [2, 2, 2]
|
||||
// _output_dtype_vector.size = 6
|
||||
// _output_shape_vector.size = 6
|
||||
|
||||
// If all the iterations have same output shapes, then
|
||||
// _execution_count = 3
|
||||
// _same_output_for_iterations = true
|
||||
// _output_sizes_vector = [2]
|
||||
// _output_dtype_vector.size = 2
|
||||
// _output_shape_vector.size = 2
|
||||
|
||||
// How many times this node has been executed.
|
||||
ABSL_CONST_INIT const char kExecutionCount[] = "_execution_count";
|
||||
|
||||
// Records the output sizes for each round of execution.
|
||||
ABSL_CONST_INIT const char kOutputSizes[] = "_output_sizes_vector";
|
||||
|
||||
// The node has been scheduled multiple times with outputs that have the same
|
||||
// shape.
|
||||
ABSL_CONST_INIT const char kOutputSame[] = "_same_output_for_iterations";
|
||||
|
||||
// Outputs DataType vector.
|
||||
ABSL_CONST_INIT const char kOutputTypes[] = "_output_dtype_vector";
|
||||
|
||||
// Outputs TensorShapeProto vector.
|
||||
ABSL_CONST_INIT const char kOutputShapes[] = "_output_shape_vector";
|
||||
|
||||
class SymbolicShapeRefiner;
|
||||
class TopoQueue;
|
||||
|
||||
|
@ -1793,6 +1793,103 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
|
||||
ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
|
||||
}
|
||||
|
||||
// Checks that annotated output shapes are ignored by default and are applied
// only when aggressive shape inference is enabled.
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, -1}))
                  .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({5, 7})})
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  {
    GraphProperties properties(item);
    // Without aggressive_shape_inference, ignore annotated information.
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false));
    const auto& props = properties.GetOutputProperties("Identity");
    // Fatal assertion: props[0] below must not be read from an empty vector.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Get unknown shapes without using annotated information.
    EXPECT_EQ("float: [-1,-1]", PropToString(prop));
  }
  {
    GraphProperties properties(item);
    // Use annotated information.
    TF_CHECK_OK(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/true));
    const auto& props = properties.GetOutputProperties("Identity");
    // Fatal assertion: props[0] below must not be read from an empty vector.
    ASSERT_EQ(1, props.size());
    const OpInfo::TensorProperties& prop = props[0];
    EXPECT_EQ(DT_FLOAT, prop.dtype());
    EXPECT_EQ(2, prop.shape().dim_size());
    // Update output shape using annotated shapes.
    EXPECT_EQ("float: [5,7]", PropToString(prop));
  }
}
|
||||
|
||||
// Checks that an annotated shape which agrees with the known inferred
// dimensions replaces the partially known inferred shape.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, 100}))
                  .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({10, 100})})
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // Fatal assertion: props[0] below must not be read from an empty vector.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Compatible shapes. Update output shape using annotated shapes.
  EXPECT_EQ("float: [10,100]", PropToString(prop));
}
|
||||
|
||||
// Checks that an annotated shape which contradicts a known inferred dimension
// is ignored and the inferred shape is kept.
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
  GrapplerItem item;
  TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("shape", PartialTensorShape({-1, 100}))
                  .Finalize(item.graph.add_node()));
  // Annotate shapes.
  TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                  .Attr("dtype", DT_FLOAT)
                  .Attr("_same_output_for_iterations", true)
                  .Attr("_output_shape_vector", {TensorShape({10, 10})})
                  .Input("Input", 0, DT_FLOAT)
                  .Finalize(item.graph.add_node()));
  GraphProperties properties(item);
  // Use annotated information.
  TF_CHECK_OK(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/true));
  const auto& props = properties.GetOutputProperties("Identity");
  // Fatal assertion: props[0] below must not be read from an empty vector.
  ASSERT_EQ(1, props.size());
  const OpInfo::TensorProperties& prop = props[0];
  EXPECT_EQ(DT_FLOAT, prop.dtype());
  EXPECT_EQ(2, prop.shape().dim_size());
  // Incompatible shapes. Do not use annotated shapes.
  EXPECT_EQ("float: [-1,100]", PropToString(prop));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -36,12 +36,6 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// Optional attribute name for Switch op as a vector of int that tells
|
||||
// which branch the Switch output is taken on every round of execution.
|
||||
// We use this side information, if provided, for scheduling ops after Switch
|
||||
// correctly (e.g., While loop).
|
||||
constexpr char kOutputSlots[] = "_output_slot_vector";
|
||||
|
||||
Costs CombineCosts(const Costs& left, const Costs& right) {
|
||||
CHECK_NE(left.max_memory, kMemoryUnknown);
|
||||
CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
|
||||
|
Loading…
Reference in New Issue
Block a user