TensorFlow: Add "SetShape" function to ShapeRefiner.

This will be used to eventually implement the same functionality
as python's 'set_shape' function, once we add the ShapeRefiner to
the C API.
Change: 132093966
This commit is contained in:
Vijay Vasudevan 2016-09-02 11:56:56 -08:00 committed by TensorFlower Gardener
parent f366a3bcfd
commit 10451eb6cf
3 changed files with 65 additions and 0 deletions

View File

@ -113,6 +113,36 @@ Status ShapeRefiner::AddNode(const Node* node) {
return Status::OK();
}
// Overrides the inferred shape of 'node's output 'output_port' with
// 'shape', provided the node is known to this refiner, the port is in
// range, and 'shape' is compatible (mergeable) with the current shape.
Status ShapeRefiner::SetShape(const Node* node, int output_port,
                              shape_inference::ShapeHandle shape) {
  // The node must have been registered via AddNode() first.
  shape_inference::InferenceContext* context = GetContext(node);
  if (context == nullptr) {
    return errors::Internal("Could not find context for ", node->name());
  }

  const int num_outputs = node->num_outputs();
  if (output_port < 0 || output_port >= num_outputs) {
    return errors::InvalidArgument(
        "output_port '", output_port, "' is out of range, ", "node '",
        node->name(), "' has ", num_outputs, " outputs");
  }

  // Compatibility check: Merge fails if 'shape' conflicts with the
  // shape previously inferred for this output.
  shape_inference::ShapeHandle existing = context->output(output_port);
  shape_inference::ShapeHandle merged;
  TF_RETURN_IF_ERROR(context->Merge(existing, shape, &merged));

  context->set_output(output_port, shape);

  // TODO(vrv): Do we need to propagate the new shape through all
  // consumers that change their outputs? At the moment, python
  // does not do this, but this seems like a nice feature.

  // TODO(vrv): We might need to keep track of the fact that the
  // existing shape is invalidated, in case we need to propagate
  // this information to remote workers.
  return Status::OK();
}
Status ShapeRefiner::ConstantValue(const Node* node, Tensor* tensor_storage,
const Tensor** input_tensor) const {
*input_tensor = nullptr;

View File

@ -46,6 +46,14 @@ class ShapeRefiner {
// - The shape inference function returns an error.
Status AddNode(const Node* node);
// Sets 'node's 'output_port' output to have shape 'shape'.
//
// Returns an error if 'node' was not previously added to this
// object, if 'output_port' is invalid, or if 'shape' is
// not compatible with the existing shape of the output
// (compatibility is checked via InferenceContext::Merge).
Status SetShape(const Node* node, int output_port,
                shape_inference::ShapeHandle shape);
// Returns the InferenceContext for 'node', if present.
shape_inference::InferenceContext* GetContext(const Node* node) const {
auto it = node_to_context_.find(node);

View File

@ -92,6 +92,33 @@ TEST(ShapeRefinerTest, BadShapes) {
ASSERT_EQ("Dimensions must be equal, but are 1 and 2", s.error_message());
}
// Exercises ShapeRefiner::SetShape: successful override, port-range
// validation, unknown node, and incompatible-shape rejection.
TEST(ShapeRefinerTest, SetShape) {
  ShapeRefiner refiner;
  Scope root = Scope::NewRootScope();

  auto a = ops::Const(root, {{1.0f}, {2.0f}});
  TF_ASSERT_OK(refiner.AddNode(a.node()));

  shape_inference::InferenceContext* ctx = refiner.GetContext(a.node());
  ASSERT_NE(nullptr, ctx);

  // Refine the [2,1] constant's shape to [2,?]; this is compatible.
  shape_inference::ShapeHandle shape = ctx->MakeShape({2, ctx->UnknownDim()});
  TF_ASSERT_OK(refiner.SetShape(a.node(), 0, shape));
  EXPECT_SHAPE("[2,?]", refiner, a, 0);

  // Ports outside [0, num_outputs) are rejected.
  ASSERT_FALSE(refiner.SetShape(a.node(), 1, shape).ok());
  ASSERT_FALSE(refiner.SetShape(a.node(), -1, shape).ok());

  // A node never passed to AddNode() has no context.
  auto b = ops::Const(root, {{1.0f}, {2.0f}});
  ASSERT_FALSE(refiner.SetShape(b.node(), 0, shape).ok());

  // A first dimension of 3 conflicts with the inferred 2.
  shape = ctx->MakeShape({3, ctx->UnknownDim()});
  ASSERT_FALSE(refiner.SetShape(a.node(), 0, shape).ok());
}
TEST(ShapeRefinerTest, PropagateConstants) {
// Reduction dimension is a variable, so we don't know its value.
// So the output shape value is unknown (though its rank is known).