Add C++ shape inference for Pack, Unpack, and Const.

Add GetAttr to shape_inference::InferenceContext.
Allow setting NodeDef in shape_inference_testutil INFER calls (with new
INFER*_WITH_DEF macro).  Fix a bug that caused a crash when an INFER..ERROR
macro called a shape inference function that did not return an error.
Change: 126350221
This commit is contained in:
A. Unique TensorFlower 2016-06-30 14:21:08 -08:00 committed by TensorFlower Gardener
parent b5c493301a
commit 605aa53d2b
8 changed files with 315 additions and 24 deletions

View File

@ -1812,10 +1812,13 @@ tf_cc_test(
],
)
tf_cc_test(
name = "ops/math_ops_test",
tf_cc_tests(
size = "small",
linkstatic = tf_kernel_tests_linkstatic(),
tests = [
"ops/array_ops_test.cc",
"ops/math_ops_test.cc",
],
deps = [
":core",
":core_cpu",

View File

@ -25,9 +25,9 @@ constexpr int32 InferenceContext::kUnknownRank;
constexpr int64 InferenceContext::kUnknownDim;
InferenceContext::InferenceContext(
const std::vector<string>& input_shapes, int num_outputs,
const std::vector<const Tensor*>& input_tensors)
: input_tensors_(input_tensors) {
const NodeDef* node_def, const std::vector<string>& input_shapes,
int num_outputs, const std::vector<const Tensor*>& input_tensors)
: input_tensors_(input_tensors), node_def_(*CHECK_NOTNULL(node_def)) {
for (const string& spec : input_shapes) {
if (spec == "?") {
inputs_.push_back(CreateUnknownShape());

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@ -80,7 +82,10 @@ class InferenceContext {
// the same Dimension*.
//
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
InferenceContext(const std::vector<string>& input_shapes, int num_outputs,
//
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
InferenceContext(const NodeDef* node_def,
const std::vector<string>& input_shapes, int num_outputs,
const std::vector<const Tensor*>& input_tensors = {});
~InferenceContext();
@ -162,6 +167,12 @@ class InferenceContext {
const Dimension* CreateDim(int64 value);
const Dimension* CreateUnknownDim();
// Look up the attr for the NodeDef being evaluated with name attr_name and
// set *value to its value. If no attr with attr_name is found in def(), or
// the attr does not have a matching type, a non-ok status will be returned.
template <class T>
Status GetAttr(StringPiece attr_name, T* value) const;
private:
Status ReturnUnknownShape(const Shape** out) {
*out = CreateUnknownShape();
@ -181,9 +192,14 @@ class InferenceContext {
std::vector<const Tensor*> input_tensors_;
std::vector<const Shape*> outputs_;
const NodeDef& node_def_;
TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
};
// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore
inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
inline Dimension::Dimension(int64 value) : value_(value) {}
@ -191,6 +207,11 @@ inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
inline Shape::Shape(const std::vector<const Dimension*> dims)
: rank_(dims.size()), dims_(dims) {}
template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
return GetNodeAttr(node_def_, attr_name, value);
}
} // namespace shape_inference
} // namespace tensorflow

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"
@ -21,7 +23,8 @@ namespace tensorflow {
namespace shape_inference {
TEST(ShapeInferenceTest, RankAndDimInspection) {
InferenceContext c({"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@ -54,7 +57,8 @@ TEST(ShapeInferenceTest, RankAndDimInspection) {
}
TEST(ShapeInferenceTest, WithRank) {
InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -91,7 +95,8 @@ TEST(ShapeInferenceTest, WithRank) {
}
TEST(ShapeInferenceTest, WithRankAtLeast) {
InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -125,7 +130,8 @@ TEST(ShapeInferenceTest, WithRankAtLeast) {
}
TEST(ShapeInferenceTest, WithValue) {
InferenceContext c({"[1,?]"}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {"[1,?]"}, 2 /* num_outputs */);
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@ -163,7 +169,8 @@ TEST(ShapeInferenceTest, WithValue) {
}
TEST(ShapeInferenceTest, MergeDim) {
InferenceContext c({"[2,?,2,1,?]"}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {"[2,?,2,1,?]"}, 2 /* num_outputs */);
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@ -202,7 +209,9 @@ TEST(ShapeInferenceTest, MergeDim) {
}
TEST(ShapeInferenceTest, MergeShape) {
InferenceContext c({"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
NodeDef def;
InferenceContext c(&def,
{"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
2 /* num_outputs */);
auto s_unknown = c.input(0);
@ -260,7 +269,8 @@ TEST(ShapeInferenceTest, MergeShape) {
}
TEST(ShapeInferenceTest, Subshape) {
InferenceContext c({"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
const Shape* unknown = c.input(1);
const Shape* out;
@ -297,7 +307,8 @@ TEST(ShapeInferenceTest, Subshape) {
}
TEST(ShapeInferenceTest, Concatenate) {
InferenceContext c({"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -322,7 +333,8 @@ TEST(ShapeInferenceTest, Concatenate) {
}
TEST(ShapeInferenceTest, CreateShape) {
InferenceContext c({"[1,2,3,?,5]"}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {"[1,2,3,?,5]"}, 2 /* num_outputs */);
std::vector<const Dimension*> dims;
auto in0 = c.input(0);
@ -341,7 +353,8 @@ TEST(ShapeInferenceTest, CreateShape) {
}
TEST(ShapeInferenceTest, CreateUnknownShape) {
InferenceContext c({}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {}, 2 /* num_outputs */);
auto u0 = c.CreateUnknownShape();
auto u1 = c.CreateUnknownShape();
@ -352,7 +365,8 @@ TEST(ShapeInferenceTest, CreateUnknownShape) {
TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
auto create = [](Tensor* t) {
InferenceContext c({"?"}, 0 /* num_outputs */, {t});
NodeDef def;
InferenceContext c(&def, {"?"}, 0 /* num_outputs */, {t});
const Shape* out;
Status s = c.CreateShapeFromShapeTensor(0, &out);
if (s.ok()) {
@ -386,7 +400,8 @@ TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
}
TEST(ShapeInferenceTest, CreateDim) {
InferenceContext c({}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {}, 2 /* num_outputs */);
auto* d0 = c.CreateDim(1);
auto* d1 = c.CreateDim(1);
@ -398,7 +413,8 @@ TEST(ShapeInferenceTest, CreateDim) {
}
TEST(ShapeInferenceTest, CreateUnknownDim) {
InferenceContext c({}, 2 /* num_outputs */);
NodeDef def;
InferenceContext c(&def, {}, 2 /* num_outputs */);
auto* d0 = c.CreateUnknownDim();
auto* d1 = c.CreateUnknownDim();
@ -410,12 +426,29 @@ TEST(ShapeInferenceTest, CreateUnknownDim) {
TEST(ShapeInferenceTest, InputTensors) {
const Tensor t1 = tensorflow::test::AsTensor<float>({10});
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
InferenceContext c({"[1]", "[2]", "[3]"}, 2 /* num_outputs */, {&t1, &t2});
NodeDef def;
InferenceContext c(&def, {"[1]", "[2]", "[3]"}, 2 /* num_outputs */,
{&t1, &t2});
EXPECT_TRUE(c.input_tensor(0) == &t1);
EXPECT_TRUE(c.input_tensor(1) == &t2);
EXPECT_TRUE(c.input_tensor(2) == nullptr);
}
TEST(ShapeInferenceTest, GetAttr) {
OpRegistrationData op_reg_data;
CHECK(OpDefBuilder("dummy").Attr("foo:string").Finalize(&op_reg_data).ok());
NodeDef def;
CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def)
.Attr("foo", "bar")
.Finalize(&def)
.ok());
InferenceContext c(&def, {}, 2 /* num_outputs */);
string value;
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
EXPECT_EQ("bar", value);
}
} // namespace shape_inference
} // namespace tensorflow

View File

@ -29,13 +29,18 @@ using shape_inference::Shape;
using errors::Unknown;
Status InferShapes(const string& op_name, const string& ins,
const string& expected_outs) {
const string& expected_outs, const NodeDef* node_def) {
const OpRegistrationData* op_reg_data;
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op_name, &op_reg_data));
const int num_outputs = op_reg_data->op_def.output_arg_size();
std::vector<string> ins_v = str_util::Split(ins, ';');
shape_inference::InferenceContext c(ins_v, num_outputs);
std::unique_ptr<const NodeDef> new_node_def;
if (node_def == nullptr) {
new_node_def.reset(new NodeDef);
node_def = new_node_def.get();
}
shape_inference::InferenceContext c(node_def, ins_v, num_outputs);
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
std::unordered_map<const Dimension*, std::pair<int, int>>

View File

@ -23,6 +23,8 @@ limitations under the License.
namespace tensorflow {
class NodeDef;
// Run shape inference for <op_name>, given inputs specified by <ins>
// and returns an error if the inferred shape does not match expected_outs.
//
@ -45,11 +47,16 @@ namespace tensorflow {
// <expected_outs> can be "e"; this is used to indicate that shape inference
// should have failed.
Status InferShapes(const string& op_name, const string& ins,
const string& expected_outs);
const string& expected_outs,
const NodeDef* node_def = nullptr);
#define INFER_OK(op, i, o) EXPECT_EQ("", InferShapes(op, i, o).error_message())
#define INFER_ERROR(s, op, i) \
EXPECT_EQ(s, InferShapes(op, i, "x").error_message())
EXPECT_EQ(s, InferShapes(op, i, "e").error_message())
#define INFER_OK_WITH_DEF(op, nd, i, o) \
EXPECT_EQ("", InferShapes(op, i, o, nd).error_message())
#define INFER_ERROR_WITH_DEF(s, op, nd, i) \
EXPECT_EQ(s, InferShapes(op, i, "e", nd).error_message())
} // namespace tensorflow

View File

@ -14,17 +14,67 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
namespace tensorflow {
typedef shape_inference::Dimension Dimension;
typedef shape_inference::InferenceContext InferenceContext;
typedef shape_inference::Shape Shape;
namespace {
Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
int32* axis) {
TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
-1 * rank_after_pack, ",", rank_after_pack,
")");
}
if (*axis < 0) *axis = (rank_after_pack + *axis);
return Status::OK();
}
} // namespace
REGISTER_OP("Pack")
.Input("values: N * T")
.Output("output: T")
.Attr("N: int >= 1")
.Attr("T: type")
.Attr("axis: int = 0")
.SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
// Validate shapes of all inputs are compatible
const Shape* cur = c->input(c->num_inputs() - 1);
for (int i = c->num_inputs() - 2; i >= 0; --i) {
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
"From merging shape ", i,
" with other shapes.");
}
if (!c->RankKnown(cur)) {
c->set_output(0, c->CreateUnknownShape());
return Status::OK();
}
// Determine the axis that will be added, converting from negative
// axes to a positive point per negative indexing rules.
int32 rank = c->Rank(cur);
int32 axis;
TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
// Copy all dimensions over, inserting a dimension of value #inputs
// at <axis>.
std::vector<const Dimension*> dims;
int index = 0;
while (index < axis) dims.push_back(c->Dim(cur, index++));
dims.push_back(c->CreateDim(c->num_inputs()));
while (index < rank) dims.push_back(c->Dim(cur, index++));
c->set_output(0, c->CreateShape(dims));
return Status::OK();
}))
.Doc(R"doc(
Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
@ -61,6 +111,29 @@ REGISTER_OP("Unpack")
.Attr("num: int >= 0")
.Attr("T: type")
.Attr("axis: int = 0")
.SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
const Shape* s = c->input(0);
const Shape* out;
if (c->RankKnown(s)) {
// Determine the axis that will be removed, converting from negative
// axes to a positive point per negative indexing rules.
int32 rank = c->Rank(s);
int32 axis;
TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
// Copy all dimensions, removing the <axis> dimension.
std::vector<const Dimension*> dims;
for (int i = 0; i < rank; ++i) {
if (i != axis) dims.push_back(c->Dim(s, i));
}
out = c->CreateShape(dims);
} else {
// All outputs are the same shape, but it's not known.
out = c->CreateUnknownShape();
}
for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
return Status::OK();
}))
.Doc(R"doc(
Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
@ -154,6 +227,18 @@ REGISTER_OP("Const")
.Output("output: dtype")
.Attr("value: tensor")
.Attr("dtype: type")
.SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
const TensorProto* proto = nullptr;
TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
TensorShape shape(proto->tensor_shape());
std::vector<const Dimension*> dims;
for (int i = 0; i < shape.dims(); ++i) {
dims.push_back(c->CreateDim(shape.dim_size(i)));
}
c->set_output(0, c->CreateShape(dims));
return Status::OK();
}))
.Doc(R"doc(
Returns a constant tensor.

View File

@ -0,0 +1,137 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
TEST(ArrayOpsTest, Pack_ShapeFn) {
std::unique_ptr<NodeDef> def_storage(new NodeDef);
NodeDef* def = def_storage.get();
auto set_axis = [def](int axis) {
TF_CHECK_OK(NodeDefBuilder("test", "Pack")
.Input({{"a", 0, DT_FLOAT}})
.Attr("axis", axis)
.Finalize(def));
};
const char op[] = "Pack";
set_axis(0);
INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
for (int axis : {0, -3}) {
set_axis(axis);
INFER_OK_WITH_DEF(op, def, "?;?", "?");
INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");
INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[3,d1_0,d0_1|d1_1]");
INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[3,d1_0,d1_1]");
}
for (int axis : {1, -2}) {
set_axis(axis);
INFER_OK_WITH_DEF(op, def, "?;?", "?");
INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,3,d0_1|d1_1]");
INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,3,d0_1|d1_1]");
INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,3,d1_1]");
}
for (int axis : {2, -1}) {
set_axis(axis);
INFER_OK_WITH_DEF(op, def, "?;?", "?");
INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,d0_1|d1_1,3]");
INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,d0_1|d1_1,3]");
INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,d1_1,3]");
}
set_axis(-4);
INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
"[1,3];[1,3];?");
set_axis(3);
INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
"[1,3];[1,3];?");
set_axis(0);
INFER_ERROR_WITH_DEF(("Shapes must be equal rank, but are 3 and 2"
"\n\tFrom merging shape 0 with other shapes."),
op, def, "[1,2,3];?;[1,4]");
}
TEST(ArrayOpsTest, UnPack_ShapeFn) {
std::unique_ptr<NodeDef> def_storage(new NodeDef);
NodeDef* def = def_storage.get();
auto set_axis = [def](int axis) {
TF_CHECK_OK(NodeDefBuilder("test", "Unpack")
.Input("a", 0, DT_FLOAT)
.Attr("axis", axis)
.Finalize(def));
};
const char op[] = "Unpack";
set_axis(0);
INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
for (int axis : {0, -3}) {
set_axis(axis);
INFER_OK_WITH_DEF(op, def, "?", "?");
INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_1,d0_2]");
INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_1,d0_2]");
}
for (int axis : {1, -2}) {
set_axis(axis);
INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_2]");
INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_2]");
}
for (int axis : {2, -1}) {
set_axis(axis);
INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_1]");
INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_1]");
}
set_axis(-4);
INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
"[1,2,3]");
set_axis(3);
INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
"[1,2,3]");
}
TEST(ArrayOpsTest, Const_ShapeFn) {
std::unique_ptr<NodeDef> def_storage(new NodeDef);
NodeDef* def = def_storage.get();
TensorProto tensor_proto;
auto* shape_proto = tensor_proto.mutable_tensor_shape();
auto rebuild_node_def = [def, &tensor_proto]() {
TF_CHECK_OK(NodeDefBuilder("test", "Const")
.Attr("value", tensor_proto)
.Finalize(def));
};
const char op[] = "Const";
TensorShape{}.AsProto(shape_proto);
rebuild_node_def();
INFER_OK_WITH_DEF(op, def, "", "[]");
TensorShape{1, 2, 3, 4}.AsProto(shape_proto);
rebuild_node_def();
INFER_OK_WITH_DEF(op, def, "", "[1,2,3,4]");
shape_proto->add_dim()->set_size(-1);
rebuild_node_def();
INFER_ERROR_WITH_DEF("Shape [1,2,3,4,-1] has negative dimensions", op, def,
"");
}
} // end namespace tensorflow