TensorFlow: Add "SetShape" function to ShapeRefiner.
This will be used to eventually implement the same functionality as python's 'set_shape' function, once we add the ShapeRefiner to the C API. Change: 132093966
This commit is contained in:
parent
f366a3bcfd
commit
10451eb6cf
@ -113,6 +113,36 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShapeRefiner::SetShape(const Node* node, int output_port,
|
||||
shape_inference::ShapeHandle shape) {
|
||||
auto c = GetContext(node);
|
||||
if (c == nullptr) {
|
||||
return errors::Internal("Could not find context for ", node->name());
|
||||
}
|
||||
|
||||
if (output_port < 0 || output_port >= node->num_outputs()) {
|
||||
return errors::InvalidArgument(
|
||||
"output_port '", output_port, "' is out of range, ", "node '",
|
||||
node->name(), "' has ", node->num_outputs(), " outputs");
|
||||
}
|
||||
|
||||
// Check compatibility
|
||||
shape_inference::ShapeHandle existing_shape = c->output(output_port);
|
||||
shape_inference::ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->Merge(existing_shape, shape, &unused));
|
||||
|
||||
c->set_output(output_port, shape);
|
||||
|
||||
// TODO(vrv): Do we need to propagate the new shape through all
|
||||
// consumers that change their outputs? At the moment, python
|
||||
// does not do this, but this seems like a nice feature.
|
||||
|
||||
// TODO(vrv): We might need to keep track of the fact that the
|
||||
// existing shape is invalidated, in case we need to propagate
|
||||
// this information to remote workers.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShapeRefiner::ConstantValue(const Node* node, Tensor* tensor_storage,
|
||||
const Tensor** input_tensor) const {
|
||||
*input_tensor = nullptr;
|
||||
|
@ -46,6 +46,14 @@ class ShapeRefiner {
|
||||
// - The shape inference function returns an error.
|
||||
Status AddNode(const Node* node);
|
||||
|
||||
// Sets 'node's 'output_port' output to have shape 'shape'.
|
||||
//
|
||||
// Returns an error if 'node' was not previously added to this
|
||||
// object, if 'output_port' is invalid, or if 'shape' is
|
||||
// not compatible with the existing shape of the output.
|
||||
Status SetShape(const Node* node, int output_port,
|
||||
shape_inference::ShapeHandle shape);
|
||||
|
||||
// Returns the InferenceContext for 'node', if present.
|
||||
shape_inference::InferenceContext* GetContext(const Node* node) const {
|
||||
auto it = node_to_context_.find(node);
|
||||
|
@ -92,6 +92,33 @@ TEST(ShapeRefinerTest, BadShapes) {
|
||||
ASSERT_EQ("Dimensions must be equal, but are 1 and 2", s.error_message());
|
||||
}
|
||||
|
||||
TEST(ShapeRefinerTest, SetShape) {
|
||||
ShapeRefiner m;
|
||||
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto a = ops::Const(root, {{1.0f}, {2.0f}});
|
||||
|
||||
TF_ASSERT_OK(m.AddNode(a.node()));
|
||||
|
||||
auto ic = m.GetContext(a.node());
|
||||
ASSERT_NE(nullptr, ic);
|
||||
shape_inference::ShapeHandle h = ic->MakeShape({2, ic->UnknownDim()});
|
||||
TF_ASSERT_OK(m.SetShape(a.node(), 0, h));
|
||||
EXPECT_SHAPE("[2,?]", m, a, 0);
|
||||
|
||||
// Out of range.
|
||||
ASSERT_FALSE(m.SetShape(a.node(), 1, h).ok());
|
||||
ASSERT_FALSE(m.SetShape(a.node(), -1, h).ok());
|
||||
|
||||
auto b = ops::Const(root, {{1.0f}, {2.0f}});
|
||||
// Forget to add node first.
|
||||
ASSERT_FALSE(m.SetShape(b.node(), 0, h).ok());
|
||||
|
||||
// Set an incompatible shape (3 vs 2)
|
||||
h = ic->MakeShape({3, ic->UnknownDim()});
|
||||
ASSERT_FALSE(m.SetShape(a.node(), 0, h).ok());
|
||||
}
|
||||
|
||||
TEST(ShapeRefinerTest, PropagateConstants) {
|
||||
// Reduction dimension is a variable, so we don't know its value.
|
||||
// So the output shape value is unknown (though its rank is known).
|
||||
|
Loading…
Reference in New Issue
Block a user