Add C++ shape inference for Pack, Unpack, and Const.
Add GetAttr to shape_inference::InferenceContext. Allow setting a NodeDef in
shape_inference_testutil INFER calls (via the new INFER*_WITH_DEF macros). Fix a
bug that caused a crash when an INFER..ERROR macro called a shape inference
function that did not return an error.
Change: 126350221
commit 605aa53d2b
parent b5c493301a
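For orientation, a minimal sketch (not part of the commit) of what the new plumbing enables: because InferenceContext now carries the NodeDef of the node being inferred, a shape function can read that node's attrs through the context. The op and its "N" attr below are illustrative only; the calls (GetAttr, CreateDim, CreateShape, set_output) are the ones added or already present in this change.

  #include "tensorflow/core/framework/shape_inference.h"

  // Hypothetical shape function for an op with an int attr "N" (sketch only).
  Status ExampleShapeFn(shape_inference::InferenceContext* c) {
    int32 n;
    TF_RETURN_IF_ERROR(c->GetAttr("N", &n));  // reads from the node's NodeDef
    std::vector<const shape_inference::Dimension*> dims;
    dims.push_back(c->CreateDim(n));          // output shape is [N]
    c->set_output(0, c->CreateShape(dims));
    return Status::OK();
  }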
tensorflow/core/BUILD
@@ -1812,10 +1812,13 @@ tf_cc_test(
     ],
 )
 
-tf_cc_test(
-    name = "ops/math_ops_test",
+tf_cc_tests(
     size = "small",
     linkstatic = tf_kernel_tests_linkstatic(),
+    tests = [
+        "ops/array_ops_test.cc",
+        "ops/math_ops_test.cc",
+    ],
     deps = [
         ":core",
         ":core_cpu",
tensorflow/core/framework/shape_inference.cc
@@ -25,9 +25,9 @@ constexpr int32 InferenceContext::kUnknownRank;
 constexpr int64 InferenceContext::kUnknownDim;
 
 InferenceContext::InferenceContext(
-    const std::vector<string>& input_shapes, int num_outputs,
-    const std::vector<const Tensor*>& input_tensors)
-    : input_tensors_(input_tensors) {
+    const NodeDef* node_def, const std::vector<string>& input_shapes,
+    int num_outputs, const std::vector<const Tensor*>& input_tensors)
+    : input_tensors_(input_tensors), node_def_(*CHECK_NOTNULL(node_def)) {
   for (const string& spec : input_shapes) {
     if (spec == "?") {
       inputs_.push_back(CreateUnknownShape());
tensorflow/core/framework/shape_inference.h
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <vector>
 
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -80,7 +82,10 @@ class InferenceContext {
   // the same Dimension*.
   //
   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
-  InferenceContext(const std::vector<string>& input_shapes, int num_outputs,
+  //
+  // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
+  InferenceContext(const NodeDef* node_def,
+                   const std::vector<string>& input_shapes, int num_outputs,
                    const std::vector<const Tensor*>& input_tensors = {});
   ~InferenceContext();
 
@@ -162,6 +167,12 @@ class InferenceContext {
   const Dimension* CreateDim(int64 value);
   const Dimension* CreateUnknownDim();
 
+  // Look up the attr for the NodeDef being evaluated with name attr_name and
+  // set *value to its value. If no attr with attr_name is found in def(), or
+  // the attr does not have a matching type, a non-ok status will be returned.
+  template <class T>
+  Status GetAttr(StringPiece attr_name, T* value) const;
+
  private:
   Status ReturnUnknownShape(const Shape** out) {
     *out = CreateUnknownShape();
@@ -181,9 +192,14 @@ class InferenceContext {
   std::vector<const Tensor*> input_tensors_;
   std::vector<const Shape*> outputs_;
 
+  const NodeDef& node_def_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
 };
 
+// -----------------------------------------------------------------------------
+// Template and inline method implementations, please ignore
+
 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
 inline Dimension::Dimension(int64 value) : value_(value) {}
 
@@ -191,6 +207,11 @@ inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
 inline Shape::Shape(const std::vector<const Dimension*> dims)
     : rank_(dims.size()), dims_(dims) {}
 
+template <class T>
+Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
+  return GetNodeAttr(node_def_, attr_name, value);
+}
+
 }  // namespace shape_inference
 }  // namespace tensorflow
 
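Since the template above simply forwards to GetNodeAttr on the stored NodeDef, GetAttr works for any attr type GetNodeAttr can deserialize. A hedged sketch using the two attr types this commit actually reads (the "axis" int in Pack/Unpack and the "value" tensor in Const), assumed to run inside some shape function that has an InferenceContext* c:

  int32 axis;
  TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));    // int attr, as in Pack/Unpack
  const TensorProto* proto = nullptr;
  TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));  // tensor attr, as in Const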
tensorflow/core/framework/shape_inference_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/framework/shape_inference.h"
 
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_def_builder.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -21,7 +23,8 @@ namespace tensorflow {
 namespace shape_inference {
 
 TEST(ShapeInferenceTest, RankAndDimInspection) {
-  InferenceContext c({"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(2, c.num_outputs());
 
@@ -54,7 +57,8 @@ TEST(ShapeInferenceTest, RankAndDimInspection) {
 }
 
 TEST(ShapeInferenceTest, WithRank) {
-  InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -91,7 +95,8 @@ TEST(ShapeInferenceTest, WithRank) {
 }
 
 TEST(ShapeInferenceTest, WithRankAtLeast) {
-  InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -125,7 +130,8 @@ TEST(ShapeInferenceTest, WithRankAtLeast) {
 }
 
 TEST(ShapeInferenceTest, WithValue) {
-  InferenceContext c({"[1,?]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,?]"}, 2 /* num_outputs */);
 
   auto d0 = c.Dim(c.input(0), 0);
   auto d1 = c.Dim(c.input(0), 1);
@@ -163,7 +169,8 @@ TEST(ShapeInferenceTest, WithValue) {
 }
 
 TEST(ShapeInferenceTest, MergeDim) {
-  InferenceContext c({"[2,?,2,1,?]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[2,?,2,1,?]"}, 2 /* num_outputs */);
 
   auto d2 = c.Dim(c.input(0), 0);
   auto d_unknown = c.Dim(c.input(0), 1);
@@ -202,7 +209,9 @@ TEST(ShapeInferenceTest, MergeDim) {
 }
 
 TEST(ShapeInferenceTest, MergeShape) {
-  InferenceContext c({"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
+  NodeDef def;
+  InferenceContext c(&def,
+                     {"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
                      2 /* num_outputs */);
 
   auto s_unknown = c.input(0);
@@ -260,7 +269,8 @@ TEST(ShapeInferenceTest, MergeShape) {
 }
 
 TEST(ShapeInferenceTest, Subshape) {
-  InferenceContext c({"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
 
   const Shape* unknown = c.input(1);
   const Shape* out;
@@ -297,7 +307,8 @@ TEST(ShapeInferenceTest, Subshape) {
 }
 
 TEST(ShapeInferenceTest, Concatenate) {
-  InferenceContext c({"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -322,7 +333,8 @@ TEST(ShapeInferenceTest, Concatenate) {
 }
 
 TEST(ShapeInferenceTest, CreateShape) {
-  InferenceContext c({"[1,2,3,?,5]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,2,3,?,5]"}, 2 /* num_outputs */);
 
   std::vector<const Dimension*> dims;
   auto in0 = c.input(0);
@@ -341,7 +353,8 @@ TEST(ShapeInferenceTest, CreateShape) {
 }
 
 TEST(ShapeInferenceTest, CreateUnknownShape) {
-  InferenceContext c({}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
 
   auto u0 = c.CreateUnknownShape();
   auto u1 = c.CreateUnknownShape();
@@ -352,7 +365,8 @@ TEST(ShapeInferenceTest, CreateUnknownShape) {
 
 TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
   auto create = [](Tensor* t) {
-    InferenceContext c({"?"}, 0 /* num_outputs */, {t});
+    NodeDef def;
+    InferenceContext c(&def, {"?"}, 0 /* num_outputs */, {t});
     const Shape* out;
     Status s = c.CreateShapeFromShapeTensor(0, &out);
     if (s.ok()) {
@@ -386,7 +400,8 @@ TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
 }
 
 TEST(ShapeInferenceTest, CreateDim) {
-  InferenceContext c({}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
 
   auto* d0 = c.CreateDim(1);
   auto* d1 = c.CreateDim(1);
@@ -398,7 +413,8 @@ TEST(ShapeInferenceTest, CreateDim) {
 }
 
 TEST(ShapeInferenceTest, CreateUnknownDim) {
-  InferenceContext c({}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
 
   auto* d0 = c.CreateUnknownDim();
   auto* d1 = c.CreateUnknownDim();
@@ -410,12 +426,29 @@ TEST(ShapeInferenceTest, CreateUnknownDim) {
 TEST(ShapeInferenceTest, InputTensors) {
   const Tensor t1 = tensorflow::test::AsTensor<float>({10});
   const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
-  InferenceContext c({"[1]", "[2]", "[3]"}, 2 /* num_outputs */, {&t1, &t2});
+  NodeDef def;
+  InferenceContext c(&def, {"[1]", "[2]", "[3]"}, 2 /* num_outputs */,
+                     {&t1, &t2});
 
   EXPECT_TRUE(c.input_tensor(0) == &t1);
   EXPECT_TRUE(c.input_tensor(1) == &t2);
   EXPECT_TRUE(c.input_tensor(2) == nullptr);
 }
 
+TEST(ShapeInferenceTest, GetAttr) {
+  OpRegistrationData op_reg_data;
+  CHECK(OpDefBuilder("dummy").Attr("foo:string").Finalize(&op_reg_data).ok());
+  NodeDef def;
+  CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def)
+            .Attr("foo", "bar")
+            .Finalize(&def)
+            .ok());
+
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
+  string value;
+  EXPECT_TRUE(c.GetAttr("foo", &value).ok());
+  EXPECT_EQ("bar", value);
+}
+
 }  // namespace shape_inference
 }  // namespace tensorflow
tensorflow/core/framework/shape_inference_testutil.cc
@@ -29,13 +29,18 @@ using shape_inference::Shape;
 using errors::Unknown;
 
 Status InferShapes(const string& op_name, const string& ins,
-                   const string& expected_outs) {
+                   const string& expected_outs, const NodeDef* node_def) {
   const OpRegistrationData* op_reg_data;
   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op_name, &op_reg_data));
   const int num_outputs = op_reg_data->op_def.output_arg_size();
 
   std::vector<string> ins_v = str_util::Split(ins, ';');
-  shape_inference::InferenceContext c(ins_v, num_outputs);
+  std::unique_ptr<const NodeDef> new_node_def;
+  if (node_def == nullptr) {
+    new_node_def.reset(new NodeDef);
+    node_def = new_node_def.get();
+  }
+  shape_inference::InferenceContext c(node_def, ins_v, num_outputs);
   TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
 
   std::unordered_map<const Dimension*, std::pair<int, int>>
tensorflow/core/framework/shape_inference_testutil.h
@@ -23,6 +23,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+class NodeDef;
+
 // Run shape inference for <op_name>, given inputs specified by <ins>
 // and returns an error if the inferred shape does not match expected_outs.
 //
@@ -45,11 +47,16 @@ namespace tensorflow {
 // <expected_outs> can be "e"; this is used to indicate that shape inference
 // should have failed.
 Status InferShapes(const string& op_name, const string& ins,
                    const string& expected_outs,
+                   const NodeDef* node_def = nullptr);
 
 #define INFER_OK(op, i, o) EXPECT_EQ("", InferShapes(op, i, o).error_message())
 #define INFER_ERROR(s, op, i) \
-  EXPECT_EQ(s, InferShapes(op, i, "x").error_message())
+  EXPECT_EQ(s, InferShapes(op, i, "e").error_message())
+#define INFER_OK_WITH_DEF(op, nd, i, o) \
+  EXPECT_EQ("", InferShapes(op, i, o, nd).error_message())
+#define INFER_ERROR_WITH_DEF(s, op, nd, i) \
+  EXPECT_EQ(s, InferShapes(op, i, "e", nd).error_message())
 
 }  // namespace tensorflow
 
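A hedged usage sketch of the new macros, mirroring the pattern used in the new ops/array_ops_test.cc below: build a NodeDef for the op under test so the shape function can read its attrs, then run the *_WITH_DEF variant.

  NodeDef def;
  TF_CHECK_OK(NodeDefBuilder("test", "Pack")
                  .Input({{"a", 0, DT_FLOAT}})
                  .Attr("axis", 0)
                  .Finalize(&def));
  // Three rank-2 inputs packed along axis 0 produce a leading dim of 3.
  INFER_OK_WITH_DEF("Pack", &def, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");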
tensorflow/core/ops/array_ops.cc
@@ -14,17 +14,67 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/util/mirror_pad_mode.h"
 #include "tensorflow/core/util/padding.h"
 
 namespace tensorflow {
 
+typedef shape_inference::Dimension Dimension;
+typedef shape_inference::InferenceContext InferenceContext;
+typedef shape_inference::Shape Shape;
+
+namespace {
+
+Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
+                               int32* axis) {
+  TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
+  if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
+    return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
+                                   -1 * rank_after_pack, ",", rank_after_pack,
+                                   ")");
+  }
+  if (*axis < 0) *axis = (rank_after_pack + *axis);
+  return Status::OK();
+}
+
+}  // namespace
+
 REGISTER_OP("Pack")
     .Input("values: N * T")
     .Output("output: T")
     .Attr("N: int >= 1")
     .Attr("T: type")
     .Attr("axis: int = 0")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      // Validate shapes of all inputs are compatible
+      const Shape* cur = c->input(c->num_inputs() - 1);
+      for (int i = c->num_inputs() - 2; i >= 0; --i) {
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
+                                        "From merging shape ", i,
+                                        " with other shapes.");
+      }
+      if (!c->RankKnown(cur)) {
+        c->set_output(0, c->CreateUnknownShape());
+        return Status::OK();
+      }
+      // Determine the axis that will be added, converting from negative
+      // axes to a positive point per negative indexing rules.
+      int32 rank = c->Rank(cur);
+      int32 axis;
+      TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
+
+      // Copy all dimensions over, inserting a dimension of value #inputs
+      // at <axis>.
+      std::vector<const Dimension*> dims;
+      int index = 0;
+      while (index < axis) dims.push_back(c->Dim(cur, index++));
+      dims.push_back(c->CreateDim(c->num_inputs()));
+      while (index < rank) dims.push_back(c->Dim(cur, index++));
+
+      c->set_output(0, c->CreateShape(dims));
+      return Status::OK();
+    }))
     .Doc(R"doc(
 Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
 
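A quick worked check of the axis helper above: when packing rank-2 inputs, rank_after_pack is 3, so valid axis values lie in [-3,3); axis = -3 passes the bound check and is normalized to 3 + (-3) = 0, while axis = 3 or axis = -4 is rejected with the "Invalid axis: ...; must be in [-3,3)" error that the new array_ops_test.cc exercises.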
@ -61,6 +111,29 @@ REGISTER_OP("Unpack")
|
|||||||
.Attr("num: int >= 0")
|
.Attr("num: int >= 0")
|
||||||
.Attr("T: type")
|
.Attr("T: type")
|
||||||
.Attr("axis: int = 0")
|
.Attr("axis: int = 0")
|
||||||
|
.SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
|
||||||
|
const Shape* s = c->input(0);
|
||||||
|
const Shape* out;
|
||||||
|
if (c->RankKnown(s)) {
|
||||||
|
// Determine the axis that will be removed, converting from negative
|
||||||
|
// axes to a positive point per negative indexing rules.
|
||||||
|
int32 rank = c->Rank(s);
|
||||||
|
int32 axis;
|
||||||
|
TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
|
||||||
|
|
||||||
|
// Copy all dimensions, removing the <axis> dimension.
|
||||||
|
std::vector<const Dimension*> dims;
|
||||||
|
for (int i = 0; i < rank; ++i) {
|
||||||
|
if (i != axis) dims.push_back(c->Dim(s, i));
|
||||||
|
}
|
||||||
|
out = c->CreateShape(dims);
|
||||||
|
} else {
|
||||||
|
// All outputs are the same shape, but it's not known.
|
||||||
|
out = c->CreateUnknownShape();
|
||||||
|
}
|
||||||
|
for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
|
||||||
|
return Status::OK();
|
||||||
|
}))
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
|
Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
|
||||||
|
|
||||||
@ -154,6 +227,18 @@ REGISTER_OP("Const")
|
|||||||
.Output("output: dtype")
|
.Output("output: dtype")
|
||||||
.Attr("value: tensor")
|
.Attr("value: tensor")
|
||||||
.Attr("dtype: type")
|
.Attr("dtype: type")
|
||||||
|
.SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
|
||||||
|
const TensorProto* proto = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
|
||||||
|
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
|
||||||
|
TensorShape shape(proto->tensor_shape());
|
||||||
|
std::vector<const Dimension*> dims;
|
||||||
|
for (int i = 0; i < shape.dims(); ++i) {
|
||||||
|
dims.push_back(c->CreateDim(shape.dim_size(i)));
|
||||||
|
}
|
||||||
|
c->set_output(0, c->CreateShape(dims));
|
||||||
|
return Status::OK();
|
||||||
|
}))
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Returns a constant tensor.
|
Returns a constant tensor.
|
||||||
|
|
||||||
|
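In other words, Const's output shape is read straight off the value attr's tensor_shape: a value with tensor_shape [1,2,3,4] yields output shape [1,2,3,4], and a negative dim_size fails TensorShape::IsValidShape with "Shape [1,2,3,4,-1] has negative dimensions", which is exactly what Const_ShapeFn in the new test file below checks.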
tensorflow/core/ops/array_ops_test.cc (new file, 137 lines)
@@ -0,0 +1,137 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(ArrayOpsTest, Pack_ShapeFn) {
+  std::unique_ptr<NodeDef> def_storage(new NodeDef);
+  NodeDef* def = def_storage.get();
+  auto set_axis = [def](int axis) {
+    TF_CHECK_OK(NodeDefBuilder("test", "Pack")
+                    .Input({{"a", 0, DT_FLOAT}})
+                    .Attr("axis", axis)
+                    .Finalize(def));
+  };
+  const char op[] = "Pack";
+
+  set_axis(0);
+  INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
+
+  for (int axis : {0, -3}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?;?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[3,d1_0,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[3,d1_0,d1_1]");
+  }
+  for (int axis : {1, -2}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?;?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,3,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,3,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,3,d1_1]");
+  }
+  for (int axis : {2, -1}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?;?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,d0_1|d1_1,3]");
+    INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,d0_1|d1_1,3]");
+    INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,d1_1,3]");
+  }
+
+  set_axis(-4);
+  INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
+                       "[1,3];[1,3];?");
+  set_axis(3);
+  INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
+                       "[1,3];[1,3];?");
+
+  set_axis(0);
+  INFER_ERROR_WITH_DEF(("Shapes must be equal rank, but are 3 and 2"
+                        "\n\tFrom merging shape 0 with other shapes."),
+                       op, def, "[1,2,3];?;[1,4]");
+}
+
+TEST(ArrayOpsTest, UnPack_ShapeFn) {
+  std::unique_ptr<NodeDef> def_storage(new NodeDef);
+  NodeDef* def = def_storage.get();
+  auto set_axis = [def](int axis) {
+    TF_CHECK_OK(NodeDefBuilder("test", "Unpack")
+                    .Input("a", 0, DT_FLOAT)
+                    .Attr("axis", axis)
+                    .Finalize(def));
+  };
+  const char op[] = "Unpack";
+
+  set_axis(0);
+  INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
+
+  for (int axis : {0, -3}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_1,d0_2]");
+    INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_1,d0_2]");
+  }
+  for (int axis : {1, -2}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_2]");
+    INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_2]");
+  }
+  for (int axis : {2, -1}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_1]");
+  }
+
+  set_axis(-4);
+  INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
+                       "[1,2,3]");
+  set_axis(3);
+  INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
+                       "[1,2,3]");
+}
+
+TEST(ArrayOpsTest, Const_ShapeFn) {
+  std::unique_ptr<NodeDef> def_storage(new NodeDef);
+  NodeDef* def = def_storage.get();
+  TensorProto tensor_proto;
+  auto* shape_proto = tensor_proto.mutable_tensor_shape();
+  auto rebuild_node_def = [def, &tensor_proto]() {
+    TF_CHECK_OK(NodeDefBuilder("test", "Const")
+                    .Attr("value", tensor_proto)
+                    .Finalize(def));
+  };
+  const char op[] = "Const";
+
+  TensorShape{}.AsProto(shape_proto);
+  rebuild_node_def();
+  INFER_OK_WITH_DEF(op, def, "", "[]");
+  TensorShape{1, 2, 3, 4}.AsProto(shape_proto);
+  rebuild_node_def();
+  INFER_OK_WITH_DEF(op, def, "", "[1,2,3,4]");
+
+  shape_proto->add_dim()->set_size(-1);
+  rebuild_node_def();
+  INFER_ERROR_WITH_DEF("Shape [1,2,3,4,-1] has negative dimensions", op, def,
+                       "");
+}
+
+}  // end namespace tensorflow