From 10451eb6cfe67a8277c39a2fd7848fbbef706f10 Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Fri, 2 Sep 2016 11:56:56 -0800 Subject: [PATCH] TensorFlow: Add "SetShape" function to ShapeRefiner. This will be used to eventually implement the same functionality as python's 'set_shape' function, once we add the ShapeRefiner to the C API. Change: 132093966 --- tensorflow/core/graph/shape_refiner.cc | 30 +++++++++++++++++++++ tensorflow/core/graph/shape_refiner.h | 8 ++++++ tensorflow/core/graph/shape_refiner_test.cc | 27 +++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/tensorflow/core/graph/shape_refiner.cc b/tensorflow/core/graph/shape_refiner.cc index e45e4e0d633..66a5202a147 100644 --- a/tensorflow/core/graph/shape_refiner.cc +++ b/tensorflow/core/graph/shape_refiner.cc @@ -113,6 +113,36 @@ Status ShapeRefiner::AddNode(const Node* node) { return Status::OK(); } +Status ShapeRefiner::SetShape(const Node* node, int output_port, + shape_inference::ShapeHandle shape) { + auto c = GetContext(node); + if (c == nullptr) { + return errors::Internal("Could not find context for ", node->name()); + } + + if (output_port < 0 || output_port >= node->num_outputs()) { + return errors::InvalidArgument( + "output_port '", output_port, "' is out of range, ", "node '", + node->name(), "' has ", node->num_outputs(), " outputs"); + } + + // Check compatibility + shape_inference::ShapeHandle existing_shape = c->output(output_port); + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &unused)); + + c->set_output(output_port, shape); + + // TODO(vrv): Do we need to propagate the new shape through all + // consumers that change their outputs? At the moment, python + // does not do this, but this seems like a nice feature. + + // TODO(vrv): We might need to keep track of the fact that the + // existing shape is invalidated, in case we need to propagate + // this information to remote workers. + return Status::OK(); +} + Status ShapeRefiner::ConstantValue(const Node* node, Tensor* tensor_storage, const Tensor** input_tensor) const { *input_tensor = nullptr; diff --git a/tensorflow/core/graph/shape_refiner.h b/tensorflow/core/graph/shape_refiner.h index 21551903c03..63838e1cfdd 100644 --- a/tensorflow/core/graph/shape_refiner.h +++ b/tensorflow/core/graph/shape_refiner.h @@ -46,6 +46,14 @@ class ShapeRefiner { // - The shape inference function returns an error. Status AddNode(const Node* node); + // Sets 'node's 'output_port' output to have shape 'shape'. + // + // Returns an error if 'node' was not previously added to this + // object, if 'output_port' is invalid, or if 'shape' is + // not compatible with the existing shape of the output. + Status SetShape(const Node* node, int output_port, + shape_inference::ShapeHandle shape); + // Returns the InferenceContext for 'node', if present. shape_inference::InferenceContext* GetContext(const Node* node) const { auto it = node_to_context_.find(node); diff --git a/tensorflow/core/graph/shape_refiner_test.cc b/tensorflow/core/graph/shape_refiner_test.cc index 94cd6dc74a9..ac4cf94546c 100644 --- a/tensorflow/core/graph/shape_refiner_test.cc +++ b/tensorflow/core/graph/shape_refiner_test.cc @@ -92,6 +92,33 @@ TEST(ShapeRefinerTest, BadShapes) { ASSERT_EQ("Dimensions must be equal, but are 1 and 2", s.error_message()); } +TEST(ShapeRefinerTest, SetShape) { + ShapeRefiner m; + + Scope root = Scope::NewRootScope(); + auto a = ops::Const(root, {{1.0f}, {2.0f}}); + + TF_ASSERT_OK(m.AddNode(a.node())); + + auto ic = m.GetContext(a.node()); + ASSERT_NE(nullptr, ic); + shape_inference::ShapeHandle h = ic->MakeShape({2, ic->UnknownDim()}); + TF_ASSERT_OK(m.SetShape(a.node(), 0, h)); + EXPECT_SHAPE("[2,?]", m, a, 0); + + // Out of range. + ASSERT_FALSE(m.SetShape(a.node(), 1, h).ok()); + ASSERT_FALSE(m.SetShape(a.node(), -1, h).ok()); + + auto b = ops::Const(root, {{1.0f}, {2.0f}}); + // Forget to add node first. + ASSERT_FALSE(m.SetShape(b.node(), 0, h).ok()); + + // Set an incompatible shape (3 vs 2) + h = ic->MakeShape({3, ic->UnknownDim()}); + ASSERT_FALSE(m.SetShape(a.node(), 0, h).ok()); +} + TEST(ShapeRefinerTest, PropagateConstants) { // Reduction dimension is a variable, so we don't know its value. // So the output shape value is unknown (though its rank is known).