Add C++ shape inference for Pack, Unpack, and Const.
Add GetAttr to shape_inference::InferenceContext. Allow setting NodeDef in shape_inference_testutil INFER calls (with new INFER*_WITH_DEF macro). Fix a bug that caused a crash when an INFER..ERROR macro called a shape inference function that did not return an error. Change: 126350221
This commit is contained in:
parent
b5c493301a
commit
605aa53d2b
@ -1812,10 +1812,13 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "ops/math_ops_test",
|
||||
tf_cc_tests(
|
||||
size = "small",
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tests = [
|
||||
"ops/array_ops_test.cc",
|
||||
"ops/math_ops_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":core",
|
||||
":core_cpu",
|
||||
|
@ -25,9 +25,9 @@ constexpr int32 InferenceContext::kUnknownRank;
|
||||
constexpr int64 InferenceContext::kUnknownDim;
|
||||
|
||||
InferenceContext::InferenceContext(
|
||||
const std::vector<string>& input_shapes, int num_outputs,
|
||||
const std::vector<const Tensor*>& input_tensors)
|
||||
: input_tensors_(input_tensors) {
|
||||
const NodeDef* node_def, const std::vector<string>& input_shapes,
|
||||
int num_outputs, const std::vector<const Tensor*>& input_tensors)
|
||||
: input_tensors_(input_tensors), node_def_(*CHECK_NOTNULL(node_def)) {
|
||||
for (const string& spec : input_shapes) {
|
||||
if (spec == "?") {
|
||||
inputs_.push_back(CreateUnknownShape());
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
@ -80,7 +82,10 @@ class InferenceContext {
|
||||
// the same Dimension*.
|
||||
//
|
||||
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
|
||||
InferenceContext(const std::vector<string>& input_shapes, int num_outputs,
|
||||
//
|
||||
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
|
||||
InferenceContext(const NodeDef* node_def,
|
||||
const std::vector<string>& input_shapes, int num_outputs,
|
||||
const std::vector<const Tensor*>& input_tensors = {});
|
||||
~InferenceContext();
|
||||
|
||||
@ -162,6 +167,12 @@ class InferenceContext {
|
||||
const Dimension* CreateDim(int64 value);
|
||||
const Dimension* CreateUnknownDim();
|
||||
|
||||
// Look up the attr for the NodeDef being evaluated with name attr_name and
|
||||
// set *value to its value. If no attr with attr_name is found in def(), or
|
||||
// the attr does not have a matching type, a non-ok status will be returned.
|
||||
template <class T>
|
||||
Status GetAttr(StringPiece attr_name, T* value) const;
|
||||
|
||||
private:
|
||||
Status ReturnUnknownShape(const Shape** out) {
|
||||
*out = CreateUnknownShape();
|
||||
@ -181,9 +192,14 @@ class InferenceContext {
|
||||
std::vector<const Tensor*> input_tensors_;
|
||||
std::vector<const Shape*> outputs_;
|
||||
|
||||
const NodeDef& node_def_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Template and inline method implementations, please ignore
|
||||
|
||||
inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
|
||||
inline Dimension::Dimension(int64 value) : value_(value) {}
|
||||
|
||||
@ -191,6 +207,11 @@ inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
|
||||
inline Shape::Shape(const std::vector<const Dimension*> dims)
|
||||
: rank_(dims.size()), dims_(dims) {}
|
||||
|
||||
template <class T>
|
||||
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
|
||||
return GetNodeAttr(node_def_, attr_name, value);
|
||||
}
|
||||
|
||||
} // namespace shape_inference
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -21,7 +23,8 @@ namespace tensorflow {
|
||||
namespace shape_inference {
|
||||
|
||||
TEST(ShapeInferenceTest, RankAndDimInspection) {
|
||||
InferenceContext c({"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(2, c.num_outputs());
|
||||
|
||||
@ -54,7 +57,8 @@ TEST(ShapeInferenceTest, RankAndDimInspection) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, WithRank) {
|
||||
InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -91,7 +95,8 @@ TEST(ShapeInferenceTest, WithRank) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, WithRankAtLeast) {
|
||||
InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -125,7 +130,8 @@ TEST(ShapeInferenceTest, WithRankAtLeast) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, WithValue) {
|
||||
InferenceContext c({"[1,?]"}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"[1,?]"}, 2 /* num_outputs */);
|
||||
|
||||
auto d0 = c.Dim(c.input(0), 0);
|
||||
auto d1 = c.Dim(c.input(0), 1);
|
||||
@ -163,7 +169,8 @@ TEST(ShapeInferenceTest, WithValue) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, MergeDim) {
|
||||
InferenceContext c({"[2,?,2,1,?]"}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"[2,?,2,1,?]"}, 2 /* num_outputs */);
|
||||
|
||||
auto d2 = c.Dim(c.input(0), 0);
|
||||
auto d_unknown = c.Dim(c.input(0), 1);
|
||||
@ -202,7 +209,9 @@ TEST(ShapeInferenceTest, MergeDim) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, MergeShape) {
|
||||
InferenceContext c({"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
|
||||
NodeDef def;
|
||||
InferenceContext c(&def,
|
||||
{"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
|
||||
2 /* num_outputs */);
|
||||
|
||||
auto s_unknown = c.input(0);
|
||||
@ -260,7 +269,8 @@ TEST(ShapeInferenceTest, MergeShape) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, Subshape) {
|
||||
InferenceContext c({"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
|
||||
|
||||
const Shape* unknown = c.input(1);
|
||||
const Shape* out;
|
||||
@ -297,7 +307,8 @@ TEST(ShapeInferenceTest, Subshape) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, Concatenate) {
|
||||
InferenceContext c({"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -322,7 +333,8 @@ TEST(ShapeInferenceTest, Concatenate) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, CreateShape) {
|
||||
InferenceContext c({"[1,2,3,?,5]"}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"[1,2,3,?,5]"}, 2 /* num_outputs */);
|
||||
|
||||
std::vector<const Dimension*> dims;
|
||||
auto in0 = c.input(0);
|
||||
@ -341,7 +353,8 @@ TEST(ShapeInferenceTest, CreateShape) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, CreateUnknownShape) {
|
||||
InferenceContext c({}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {}, 2 /* num_outputs */);
|
||||
|
||||
auto u0 = c.CreateUnknownShape();
|
||||
auto u1 = c.CreateUnknownShape();
|
||||
@ -352,7 +365,8 @@ TEST(ShapeInferenceTest, CreateUnknownShape) {
|
||||
|
||||
TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
|
||||
auto create = [](Tensor* t) {
|
||||
InferenceContext c({"?"}, 0 /* num_outputs */, {t});
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"?"}, 0 /* num_outputs */, {t});
|
||||
const Shape* out;
|
||||
Status s = c.CreateShapeFromShapeTensor(0, &out);
|
||||
if (s.ok()) {
|
||||
@ -386,7 +400,8 @@ TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, CreateDim) {
|
||||
InferenceContext c({}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {}, 2 /* num_outputs */);
|
||||
|
||||
auto* d0 = c.CreateDim(1);
|
||||
auto* d1 = c.CreateDim(1);
|
||||
@ -398,7 +413,8 @@ TEST(ShapeInferenceTest, CreateDim) {
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, CreateUnknownDim) {
|
||||
InferenceContext c({}, 2 /* num_outputs */);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {}, 2 /* num_outputs */);
|
||||
|
||||
auto* d0 = c.CreateUnknownDim();
|
||||
auto* d1 = c.CreateUnknownDim();
|
||||
@ -410,12 +426,29 @@ TEST(ShapeInferenceTest, CreateUnknownDim) {
|
||||
TEST(ShapeInferenceTest, InputTensors) {
|
||||
const Tensor t1 = tensorflow::test::AsTensor<float>({10});
|
||||
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
|
||||
InferenceContext c({"[1]", "[2]", "[3]"}, 2 /* num_outputs */, {&t1, &t2});
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, {"[1]", "[2]", "[3]"}, 2 /* num_outputs */,
|
||||
{&t1, &t2});
|
||||
|
||||
EXPECT_TRUE(c.input_tensor(0) == &t1);
|
||||
EXPECT_TRUE(c.input_tensor(1) == &t2);
|
||||
EXPECT_TRUE(c.input_tensor(2) == nullptr);
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, GetAttr) {
|
||||
OpRegistrationData op_reg_data;
|
||||
CHECK(OpDefBuilder("dummy").Attr("foo:string").Finalize(&op_reg_data).ok());
|
||||
NodeDef def;
|
||||
CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def)
|
||||
.Attr("foo", "bar")
|
||||
.Finalize(&def)
|
||||
.ok());
|
||||
|
||||
InferenceContext c(&def, {}, 2 /* num_outputs */);
|
||||
string value;
|
||||
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
|
||||
EXPECT_EQ("bar", value);
|
||||
}
|
||||
|
||||
} // namespace shape_inference
|
||||
} // namespace tensorflow
|
||||
|
@ -29,13 +29,18 @@ using shape_inference::Shape;
|
||||
using errors::Unknown;
|
||||
|
||||
Status InferShapes(const string& op_name, const string& ins,
|
||||
const string& expected_outs) {
|
||||
const string& expected_outs, const NodeDef* node_def) {
|
||||
const OpRegistrationData* op_reg_data;
|
||||
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op_name, &op_reg_data));
|
||||
const int num_outputs = op_reg_data->op_def.output_arg_size();
|
||||
|
||||
std::vector<string> ins_v = str_util::Split(ins, ';');
|
||||
shape_inference::InferenceContext c(ins_v, num_outputs);
|
||||
std::unique_ptr<const NodeDef> new_node_def;
|
||||
if (node_def == nullptr) {
|
||||
new_node_def.reset(new NodeDef);
|
||||
node_def = new_node_def.get();
|
||||
}
|
||||
shape_inference::InferenceContext c(node_def, ins_v, num_outputs);
|
||||
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
|
||||
|
||||
std::unordered_map<const Dimension*, std::pair<int, int>>
|
||||
|
@ -23,6 +23,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class NodeDef;
|
||||
|
||||
// Run shape inference for <op_name>, given inputs specified by <ins>
|
||||
// and returns an error if the inferred shape does not match expected_outs.
|
||||
//
|
||||
@ -45,11 +47,16 @@ namespace tensorflow {
|
||||
// <expected_outs> can be "e"; this is used to indicate that shape inference
|
||||
// should have failed.
|
||||
Status InferShapes(const string& op_name, const string& ins,
|
||||
const string& expected_outs);
|
||||
const string& expected_outs,
|
||||
const NodeDef* node_def = nullptr);
|
||||
|
||||
#define INFER_OK(op, i, o) EXPECT_EQ("", InferShapes(op, i, o).error_message())
|
||||
#define INFER_ERROR(s, op, i) \
|
||||
EXPECT_EQ(s, InferShapes(op, i, "x").error_message())
|
||||
EXPECT_EQ(s, InferShapes(op, i, "e").error_message())
|
||||
#define INFER_OK_WITH_DEF(op, nd, i, o) \
|
||||
EXPECT_EQ("", InferShapes(op, i, o, nd).error_message())
|
||||
#define INFER_ERROR_WITH_DEF(s, op, nd, i) \
|
||||
EXPECT_EQ(s, InferShapes(op, i, "e", nd).error_message())
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -14,17 +14,67 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef shape_inference::Dimension Dimension;
|
||||
typedef shape_inference::InferenceContext InferenceContext;
|
||||
typedef shape_inference::Shape Shape;
|
||||
|
||||
namespace {
|
||||
|
||||
Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
|
||||
int32* axis) {
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
|
||||
if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
|
||||
return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
|
||||
-1 * rank_after_pack, ",", rank_after_pack,
|
||||
")");
|
||||
}
|
||||
if (*axis < 0) *axis = (rank_after_pack + *axis);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_OP("Pack")
|
||||
.Input("values: N * T")
|
||||
.Output("output: T")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("T: type")
|
||||
.Attr("axis: int = 0")
|
||||
.SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
|
||||
// Validate shapes of all inputs are compatible
|
||||
const Shape* cur = c->input(c->num_inputs() - 1);
|
||||
for (int i = c->num_inputs() - 2; i >= 0; --i) {
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
|
||||
"From merging shape ", i,
|
||||
" with other shapes.");
|
||||
}
|
||||
if (!c->RankKnown(cur)) {
|
||||
c->set_output(0, c->CreateUnknownShape());
|
||||
return Status::OK();
|
||||
}
|
||||
// Determine the axis that will be added, converting from negative
|
||||
// axes to a positive point per negative indexing rules.
|
||||
int32 rank = c->Rank(cur);
|
||||
int32 axis;
|
||||
TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
|
||||
|
||||
// Copy all dimensions over, inserting a dimension of value #inputs
|
||||
// at <axis>.
|
||||
std::vector<const Dimension*> dims;
|
||||
int index = 0;
|
||||
while (index < axis) dims.push_back(c->Dim(cur, index++));
|
||||
dims.push_back(c->CreateDim(c->num_inputs()));
|
||||
while (index < rank) dims.push_back(c->Dim(cur, index++));
|
||||
|
||||
c->set_output(0, c->CreateShape(dims));
|
||||
return Status::OK();
|
||||
}))
|
||||
.Doc(R"doc(
|
||||
Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
|
||||
|
||||
@ -61,6 +111,29 @@ REGISTER_OP("Unpack")
|
||||
.Attr("num: int >= 0")
|
||||
.Attr("T: type")
|
||||
.Attr("axis: int = 0")
|
||||
.SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
|
||||
const Shape* s = c->input(0);
|
||||
const Shape* out;
|
||||
if (c->RankKnown(s)) {
|
||||
// Determine the axis that will be removed, converting from negative
|
||||
// axes to a positive point per negative indexing rules.
|
||||
int32 rank = c->Rank(s);
|
||||
int32 axis;
|
||||
TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
|
||||
|
||||
// Copy all dimensions, removing the <axis> dimension.
|
||||
std::vector<const Dimension*> dims;
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
if (i != axis) dims.push_back(c->Dim(s, i));
|
||||
}
|
||||
out = c->CreateShape(dims);
|
||||
} else {
|
||||
// All outputs are the same shape, but it's not known.
|
||||
out = c->CreateUnknownShape();
|
||||
}
|
||||
for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
|
||||
return Status::OK();
|
||||
}))
|
||||
.Doc(R"doc(
|
||||
Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
|
||||
|
||||
@ -154,6 +227,18 @@ REGISTER_OP("Const")
|
||||
.Output("output: dtype")
|
||||
.Attr("value: tensor")
|
||||
.Attr("dtype: type")
|
||||
.SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
|
||||
const TensorProto* proto = nullptr;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
|
||||
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
|
||||
TensorShape shape(proto->tensor_shape());
|
||||
std::vector<const Dimension*> dims;
|
||||
for (int i = 0; i < shape.dims(); ++i) {
|
||||
dims.push_back(c->CreateDim(shape.dim_size(i)));
|
||||
}
|
||||
c->set_output(0, c->CreateShape(dims));
|
||||
return Status::OK();
|
||||
}))
|
||||
.Doc(R"doc(
|
||||
Returns a constant tensor.
|
||||
|
||||
|
137
tensorflow/core/ops/array_ops_test.cc
Normal file
137
tensorflow/core/ops/array_ops_test.cc
Normal file
@ -0,0 +1,137 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference_testutil.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TEST(ArrayOpsTest, Pack_ShapeFn) {
|
||||
std::unique_ptr<NodeDef> def_storage(new NodeDef);
|
||||
NodeDef* def = def_storage.get();
|
||||
auto set_axis = [def](int axis) {
|
||||
TF_CHECK_OK(NodeDefBuilder("test", "Pack")
|
||||
.Input({{"a", 0, DT_FLOAT}})
|
||||
.Attr("axis", axis)
|
||||
.Finalize(def));
|
||||
};
|
||||
const char op[] = "Pack";
|
||||
|
||||
set_axis(0);
|
||||
INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
|
||||
|
||||
for (int axis : {0, -3}) {
|
||||
set_axis(axis);
|
||||
INFER_OK_WITH_DEF(op, def, "?;?", "?");
|
||||
INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");
|
||||
INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[3,d1_0,d0_1|d1_1]");
|
||||
INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[3,d1_0,d1_1]");
|
||||
}
|
||||
for (int axis : {1, -2}) {
|
||||
set_axis(axis);
|
||||
INFER_OK_WITH_DEF(op, def, "?;?", "?");
|
||||
INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,3,d0_1|d1_1]");
|
||||
INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,3,d0_1|d1_1]");
|
||||
INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,3,d1_1]");
|
||||
}
|
||||
for (int axis : {2, -1}) {
|
||||
set_axis(axis);
|
||||
INFER_OK_WITH_DEF(op, def, "?;?", "?");
|
||||
INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,d0_1|d1_1,3]");
|
||||
INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,d0_1|d1_1,3]");
|
||||
INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,d1_1,3]");
|
||||
}
|
||||
|
||||
set_axis(-4);
|
||||
INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
|
||||
"[1,3];[1,3];?");
|
||||
set_axis(3);
|
||||
INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
|
||||
"[1,3];[1,3];?");
|
||||
|
||||
set_axis(0);
|
||||
INFER_ERROR_WITH_DEF(("Shapes must be equal rank, but are 3 and 2"
|
||||
"\n\tFrom merging shape 0 with other shapes."),
|
||||
op, def, "[1,2,3];?;[1,4]");
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, UnPack_ShapeFn) {
|
||||
std::unique_ptr<NodeDef> def_storage(new NodeDef);
|
||||
NodeDef* def = def_storage.get();
|
||||
auto set_axis = [def](int axis) {
|
||||
TF_CHECK_OK(NodeDefBuilder("test", "Unpack")
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("axis", axis)
|
||||
.Finalize(def));
|
||||
};
|
||||
const char op[] = "Unpack";
|
||||
|
||||
set_axis(0);
|
||||
INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
|
||||
|
||||
for (int axis : {0, -3}) {
|
||||
set_axis(axis);
|
||||
INFER_OK_WITH_DEF(op, def, "?", "?");
|
||||
INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_1,d0_2]");
|
||||
INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_1,d0_2]");
|
||||
}
|
||||
for (int axis : {1, -2}) {
|
||||
set_axis(axis);
|
||||
INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_2]");
|
||||
INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_2]");
|
||||
}
|
||||
for (int axis : {2, -1}) {
|
||||
set_axis(axis);
|
||||
INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_1]");
|
||||
INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_1]");
|
||||
}
|
||||
|
||||
set_axis(-4);
|
||||
INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
|
||||
"[1,2,3]");
|
||||
set_axis(3);
|
||||
INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
|
||||
"[1,2,3]");
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, Const_ShapeFn) {
|
||||
std::unique_ptr<NodeDef> def_storage(new NodeDef);
|
||||
NodeDef* def = def_storage.get();
|
||||
TensorProto tensor_proto;
|
||||
auto* shape_proto = tensor_proto.mutable_tensor_shape();
|
||||
auto rebuild_node_def = [def, &tensor_proto]() {
|
||||
TF_CHECK_OK(NodeDefBuilder("test", "Const")
|
||||
.Attr("value", tensor_proto)
|
||||
.Finalize(def));
|
||||
};
|
||||
const char op[] = "Const";
|
||||
|
||||
TensorShape{}.AsProto(shape_proto);
|
||||
rebuild_node_def();
|
||||
INFER_OK_WITH_DEF(op, def, "", "[]");
|
||||
TensorShape{1, 2, 3, 4}.AsProto(shape_proto);
|
||||
rebuild_node_def();
|
||||
INFER_OK_WITH_DEF(op, def, "", "[1,2,3,4]");
|
||||
|
||||
shape_proto->add_dim()->set_size(-1);
|
||||
rebuild_node_def();
|
||||
INFER_ERROR_WITH_DEF("Shape [1,2,3,4,-1] has negative dimensions", op, def,
|
||||
"");
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
Loading…
Reference in New Issue
Block a user