From 605aa53d2bb65a8a38dc72725e28ebe75a949d5a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 30 Jun 2016 14:21:08 -0800
Subject: [PATCH] Add C++ shape inference for Pack, Unpack, and Const. Add
 GetAttr to shape_inference::InferenceContext. Allow setting NodeDef in
 shape_inference_testutil INFER calls (with new INFER*_WITH_DEF macro).  Fix a
 bug that caused a crash when an INFER..ERROR macro called a shape inference
 function that did not return an error. Change: 126350221

---
 tensorflow/core/BUILD                         |   7 +-
 tensorflow/core/framework/shape_inference.cc  |   6 +-
 tensorflow/core/framework/shape_inference.h   |  23 ++-
 .../core/framework/shape_inference_test.cc    |  61 ++++++--
 .../framework/shape_inference_testutil.cc     |   9 +-
 .../core/framework/shape_inference_testutil.h |  11 +-
 tensorflow/core/ops/array_ops.cc              |  85 +++++++++++
 tensorflow/core/ops/array_ops_test.cc         | 137 ++++++++++++++++++
 8 files changed, 315 insertions(+), 24 deletions(-)
 create mode 100644 tensorflow/core/ops/array_ops_test.cc

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b684522eb6d..b2a928867f4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1812,10 +1812,13 @@ tf_cc_test(
     ],
 )
 
-tf_cc_test(
-    name = "ops/math_ops_test",
+tf_cc_tests(
     size = "small",
     linkstatic = tf_kernel_tests_linkstatic(),
+    tests = [
+        "ops/array_ops_test.cc",
+        "ops/math_ops_test.cc",
+    ],
     deps = [
         ":core",
         ":core_cpu",
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 2df57d6cab3..bd8e6ea3094 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -25,9 +25,9 @@ constexpr int32 InferenceContext::kUnknownRank;
 constexpr int64 InferenceContext::kUnknownDim;
 
 InferenceContext::InferenceContext(
-    const std::vector<string>& input_shapes, int num_outputs,
-    const std::vector<const Tensor*>& input_tensors)
-    : input_tensors_(input_tensors) {
+    const NodeDef* node_def, const std::vector<string>& input_shapes,
+    int num_outputs, const std::vector<const Tensor*>& input_tensors)
+    : input_tensors_(input_tensors), node_def_(*CHECK_NOTNULL(node_def)) {
   for (const string& spec : input_shapes) {
     if (spec == "?") {
       inputs_.push_back(CreateUnknownShape());
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index bb6a66dc533..6385177bc19 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <vector>
 
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -80,7 +82,10 @@ class InferenceContext {
   //               the same Dimension*.
   //
   // <input_tensors> is NULL-padded to be the same size as <input_shapes>.
-  InferenceContext(const std::vector<string>& input_shapes, int num_outputs,
+  //
+  // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
+  InferenceContext(const NodeDef* node_def,
+                   const std::vector<string>& input_shapes, int num_outputs,
                    const std::vector<const Tensor*>& input_tensors = {});
   ~InferenceContext();
 
@@ -162,6 +167,12 @@ class InferenceContext {
   const Dimension* CreateDim(int64 value);
   const Dimension* CreateUnknownDim();
 
+  // Look up the attr for the NodeDef being evaluated with name attr_name and
+  // set *value to its value.  If no attr with attr_name is found in def(), or
+  // the attr does not have a matching type, a non-ok status will be returned.
+  template <class T>
+  Status GetAttr(StringPiece attr_name, T* value) const;
+
  private:
   Status ReturnUnknownShape(const Shape** out) {
     *out = CreateUnknownShape();
@@ -181,9 +192,14 @@ class InferenceContext {
   std::vector<const Tensor*> input_tensors_;
   std::vector<const Shape*> outputs_;
 
+  const NodeDef& node_def_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
 };
 
+// -----------------------------------------------------------------------------
+// Template and inline method implementations, please ignore
+
 inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
 inline Dimension::Dimension(int64 value) : value_(value) {}
 
@@ -191,6 +207,11 @@ inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
 inline Shape::Shape(const std::vector<const Dimension*> dims)
     : rank_(dims.size()), dims_(dims) {}
 
+template <class T>
+Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
+  return GetNodeAttr(node_def_, attr_name, value);
+}
+
 }  // namespace shape_inference
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index e4ca7645b2e..e52d1c5a2d6 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/framework/shape_inference.h"
 
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_def_builder.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -21,7 +23,8 @@ namespace tensorflow {
 namespace shape_inference {
 
 TEST(ShapeInferenceTest, RankAndDimInspection) {
-  InferenceContext c({"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
   EXPECT_EQ(3, c.num_inputs());
   EXPECT_EQ(2, c.num_outputs());
 
@@ -54,7 +57,8 @@ TEST(ShapeInferenceTest, RankAndDimInspection) {
 }
 
 TEST(ShapeInferenceTest, WithRank) {
-  InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -91,7 +95,8 @@ TEST(ShapeInferenceTest, WithRank) {
 }
 
 TEST(ShapeInferenceTest, WithRankAtLeast) {
-  InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -125,7 +130,8 @@ TEST(ShapeInferenceTest, WithRankAtLeast) {
 }
 
 TEST(ShapeInferenceTest, WithValue) {
-  InferenceContext c({"[1,?]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,?]"}, 2 /* num_outputs */);
 
   auto d0 = c.Dim(c.input(0), 0);
   auto d1 = c.Dim(c.input(0), 1);
@@ -163,7 +169,8 @@ TEST(ShapeInferenceTest, WithValue) {
 }
 
 TEST(ShapeInferenceTest, MergeDim) {
-  InferenceContext c({"[2,?,2,1,?]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[2,?,2,1,?]"}, 2 /* num_outputs */);
 
   auto d2 = c.Dim(c.input(0), 0);
   auto d_unknown = c.Dim(c.input(0), 1);
@@ -202,7 +209,9 @@ TEST(ShapeInferenceTest, MergeDim) {
 }
 
 TEST(ShapeInferenceTest, MergeShape) {
-  InferenceContext c({"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
+  NodeDef def;
+  InferenceContext c(&def,
+                     {"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
                      2 /* num_outputs */);
 
   auto s_unknown = c.input(0);
@@ -260,7 +269,8 @@ TEST(ShapeInferenceTest, MergeShape) {
 }
 
 TEST(ShapeInferenceTest, Subshape) {
-  InferenceContext c({"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
 
   const Shape* unknown = c.input(1);
   const Shape* out;
@@ -297,7 +307,8 @@ TEST(ShapeInferenceTest, Subshape) {
 }
 
 TEST(ShapeInferenceTest, Concatenate) {
-  InferenceContext c({"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
 
   auto in0 = c.input(0);
   auto in1 = c.input(1);
@@ -322,7 +333,8 @@ TEST(ShapeInferenceTest, Concatenate) {
 }
 
 TEST(ShapeInferenceTest, CreateShape) {
-  InferenceContext c({"[1,2,3,?,5]"}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {"[1,2,3,?,5]"}, 2 /* num_outputs */);
 
   std::vector<const Dimension*> dims;
   auto in0 = c.input(0);
@@ -341,7 +353,8 @@ TEST(ShapeInferenceTest, CreateShape) {
 }
 
 TEST(ShapeInferenceTest, CreateUnknownShape) {
-  InferenceContext c({}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
 
   auto u0 = c.CreateUnknownShape();
   auto u1 = c.CreateUnknownShape();
@@ -352,7 +365,8 @@ TEST(ShapeInferenceTest, CreateUnknownShape) {
 
 TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
   auto create = [](Tensor* t) {
-    InferenceContext c({"?"}, 0 /* num_outputs */, {t});
+    NodeDef def;
+    InferenceContext c(&def, {"?"}, 0 /* num_outputs */, {t});
     const Shape* out;
     Status s = c.CreateShapeFromShapeTensor(0, &out);
     if (s.ok()) {
@@ -386,7 +400,8 @@ TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
 }
 
 TEST(ShapeInferenceTest, CreateDim) {
-  InferenceContext c({}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
 
   auto* d0 = c.CreateDim(1);
   auto* d1 = c.CreateDim(1);
@@ -398,7 +413,8 @@ TEST(ShapeInferenceTest, CreateDim) {
 }
 
 TEST(ShapeInferenceTest, CreateUnknownDim) {
-  InferenceContext c({}, 2 /* num_outputs */);
+  NodeDef def;
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
 
   auto* d0 = c.CreateUnknownDim();
   auto* d1 = c.CreateUnknownDim();
@@ -410,12 +426,29 @@ TEST(ShapeInferenceTest, CreateUnknownDim) {
 TEST(ShapeInferenceTest, InputTensors) {
   const Tensor t1 = tensorflow::test::AsTensor<float>({10});
   const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
-  InferenceContext c({"[1]", "[2]", "[3]"}, 2 /* num_outputs */, {&t1, &t2});
+  NodeDef def;
+  InferenceContext c(&def, {"[1]", "[2]", "[3]"}, 2 /* num_outputs */,
+                     {&t1, &t2});
 
   EXPECT_TRUE(c.input_tensor(0) == &t1);
   EXPECT_TRUE(c.input_tensor(1) == &t2);
   EXPECT_TRUE(c.input_tensor(2) == nullptr);
 }
 
+TEST(ShapeInferenceTest, GetAttr) {
+  OpRegistrationData op_reg_data;
+  CHECK(OpDefBuilder("dummy").Attr("foo:string").Finalize(&op_reg_data).ok());
+  NodeDef def;
+  CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def)
+            .Attr("foo", "bar")
+            .Finalize(&def)
+            .ok());
+
+  InferenceContext c(&def, {}, 2 /* num_outputs */);
+  string value;
+  EXPECT_TRUE(c.GetAttr("foo", &value).ok());
+  EXPECT_EQ("bar", value);
+}
+
 }  // namespace shape_inference
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index 9b56014edbe..f771e477644 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -29,13 +29,18 @@ using shape_inference::Shape;
 using errors::Unknown;
 
 Status InferShapes(const string& op_name, const string& ins,
-                   const string& expected_outs) {
+                   const string& expected_outs, const NodeDef* node_def) {
   const OpRegistrationData* op_reg_data;
   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op_name, &op_reg_data));
   const int num_outputs = op_reg_data->op_def.output_arg_size();
 
   std::vector<string> ins_v = str_util::Split(ins, ';');
-  shape_inference::InferenceContext c(ins_v, num_outputs);
+  std::unique_ptr<const NodeDef> new_node_def;
+  if (node_def == nullptr) {
+    new_node_def.reset(new NodeDef);
+    node_def = new_node_def.get();
+  }
+  shape_inference::InferenceContext c(node_def, ins_v, num_outputs);
   TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
 
   std::unordered_map<const Dimension*, std::pair<int, int>>
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index f2581247d9e..221ec875fb0 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -23,6 +23,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+class NodeDef;
+
 // Run shape inference for <op_name>, given inputs specified by <ins>
 // and returns an error if the inferred shape does not match expected_outs.
 //
@@ -45,11 +47,16 @@ namespace tensorflow {
 // <expected_outs> can be "e"; this is used to indicate that shape inference
 // should have failed.
 Status InferShapes(const string& op_name, const string& ins,
-                   const string& expected_outs);
+                   const string& expected_outs,
+                   const NodeDef* node_def = nullptr);
 
 #define INFER_OK(op, i, o) EXPECT_EQ("", InferShapes(op, i, o).error_message())
 #define INFER_ERROR(s, op, i) \
-  EXPECT_EQ(s, InferShapes(op, i, "x").error_message())
+  EXPECT_EQ(s, InferShapes(op, i, "e").error_message())
+#define INFER_OK_WITH_DEF(op, nd, i, o) \
+  EXPECT_EQ("", InferShapes(op, i, o, nd).error_message())
+#define INFER_ERROR_WITH_DEF(s, op, nd, i) \
+  EXPECT_EQ(s, InferShapes(op, i, "e", nd).error_message())
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index dc96588f73a..4ef3a48221a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -14,17 +14,67 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/util/mirror_pad_mode.h"
 #include "tensorflow/core/util/padding.h"
 
 namespace tensorflow {
 
+typedef shape_inference::Dimension Dimension;
+typedef shape_inference::InferenceContext InferenceContext;
+typedef shape_inference::Shape Shape;
+
+namespace {
+
+Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
+                               int32* axis) {
+  TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
+  if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
+    return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
+                                   -1 * rank_after_pack, ",", rank_after_pack,
+                                   ")");
+  }
+  if (*axis < 0) *axis = (rank_after_pack + *axis);
+  return Status::OK();
+}
+
+}  // namespace
+
 REGISTER_OP("Pack")
     .Input("values: N * T")
     .Output("output: T")
     .Attr("N: int >= 1")
     .Attr("T: type")
     .Attr("axis: int = 0")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      // Validate shapes of all inputs are compatible
+      const Shape* cur = c->input(c->num_inputs() - 1);
+      for (int i = c->num_inputs() - 2; i >= 0; --i) {
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
+                                        "From merging shape ", i,
+                                        " with other shapes.");
+      }
+      if (!c->RankKnown(cur)) {
+        c->set_output(0, c->CreateUnknownShape());
+        return Status::OK();
+      }
+      // Determine the axis that will be added, converting from negative
+      // axes to a positive point per negative indexing rules.
+      int32 rank = c->Rank(cur);
+      int32 axis;
+      TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
+
+      // Copy all dimensions over, inserting a dimension of value #inputs
+      // at <axis>.
+      std::vector<const Dimension*> dims;
+      int index = 0;
+      while (index < axis) dims.push_back(c->Dim(cur, index++));
+      dims.push_back(c->CreateDim(c->num_inputs()));
+      while (index < rank) dims.push_back(c->Dim(cur, index++));
+
+      c->set_output(0, c->CreateShape(dims));
+      return Status::OK();
+    }))
     .Doc(R"doc(
 Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
 
@@ -61,6 +111,29 @@ REGISTER_OP("Unpack")
     .Attr("num: int >= 0")
     .Attr("T: type")
     .Attr("axis: int = 0")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      const Shape* s = c->input(0);
+      const Shape* out;
+      if (c->RankKnown(s)) {
+        // Determine the axis that will be removed, converting from negative
+        // axes to a positive point per negative indexing rules.
+        int32 rank = c->Rank(s);
+        int32 axis;
+        TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
+
+        // Copy all dimensions, removing the <axis> dimension.
+        std::vector<const Dimension*> dims;
+        for (int i = 0; i < rank; ++i) {
+          if (i != axis) dims.push_back(c->Dim(s, i));
+        }
+        out = c->CreateShape(dims);
+      } else {
+        // All outputs are the same shape, but it's not known.
+        out = c->CreateUnknownShape();
+      }
+      for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
+      return Status::OK();
+    }))
     .Doc(R"doc(
 Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
 
@@ -154,6 +227,18 @@ REGISTER_OP("Const")
     .Output("output: dtype")
     .Attr("value: tensor")
     .Attr("dtype: type")
+    .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+      const TensorProto* proto = nullptr;
+      TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
+      TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
+      TensorShape shape(proto->tensor_shape());
+      std::vector<const Dimension*> dims;
+      for (int i = 0; i < shape.dims(); ++i) {
+        dims.push_back(c->CreateDim(shape.dim_size(i)));
+      }
+      c->set_output(0, c->CreateShape(dims));
+      return Status::OK();
+    }))
     .Doc(R"doc(
 Returns a constant tensor.
 
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
new file mode 100644
index 00000000000..19dfa293584
--- /dev/null
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(ArrayOpsTest, Pack_ShapeFn) {
+  std::unique_ptr<NodeDef> def_storage(new NodeDef);
+  NodeDef* def = def_storage.get();
+  auto set_axis = [def](int axis) {
+    TF_CHECK_OK(NodeDefBuilder("test", "Pack")
+                    .Input({{"a", 0, DT_FLOAT}})
+                    .Attr("axis", axis)
+                    .Finalize(def));
+  };
+  const char op[] = "Pack";
+
+  set_axis(0);
+  INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
+
+  for (int axis : {0, -3}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?;?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[3,d1_0,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[3,d1_0,d1_1]");
+  }
+  for (int axis : {1, -2}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?;?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,3,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,3,d0_1|d1_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,3,d1_1]");
+  }
+  for (int axis : {2, -1}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?;?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,d0_1|d1_1,3]");
+    INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,d0_1|d1_1,3]");
+    INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,d1_1,3]");
+  }
+
+  set_axis(-4);
+  INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
+                       "[1,3];[1,3];?");
+  set_axis(3);
+  INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
+                       "[1,3];[1,3];?");
+
+  set_axis(0);
+  INFER_ERROR_WITH_DEF(("Shapes must be equal rank, but are 3 and 2"
+                        "\n\tFrom merging shape 0 with other shapes."),
+                       op, def, "[1,2,3];?;[1,4]");
+}
+
+TEST(ArrayOpsTest, UnPack_ShapeFn) {
+  std::unique_ptr<NodeDef> def_storage(new NodeDef);
+  NodeDef* def = def_storage.get();
+  auto set_axis = [def](int axis) {
+    TF_CHECK_OK(NodeDefBuilder("test", "Unpack")
+                    .Input("a", 0, DT_FLOAT)
+                    .Attr("axis", axis)
+                    .Finalize(def));
+  };
+  const char op[] = "Unpack";
+
+  set_axis(0);
+  INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
+
+  for (int axis : {0, -3}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "?", "?");
+    INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_1,d0_2]");
+    INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_1,d0_2]");
+  }
+  for (int axis : {1, -2}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_2]");
+    INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_2]");
+  }
+  for (int axis : {2, -1}) {
+    set_axis(axis);
+    INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_1]");
+    INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_1]");
+  }
+
+  set_axis(-4);
+  INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
+                       "[1,2,3]");
+  set_axis(3);
+  INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
+                       "[1,2,3]");
+}
+
+TEST(ArrayOpsTest, Const_ShapeFn) {
+  std::unique_ptr<NodeDef> def_storage(new NodeDef);
+  NodeDef* def = def_storage.get();
+  TensorProto tensor_proto;
+  auto* shape_proto = tensor_proto.mutable_tensor_shape();
+  auto rebuild_node_def = [def, &tensor_proto]() {
+    TF_CHECK_OK(NodeDefBuilder("test", "Const")
+                    .Attr("value", tensor_proto)
+                    .Finalize(def));
+  };
+  const char op[] = "Const";
+
+  TensorShape{}.AsProto(shape_proto);
+  rebuild_node_def();
+  INFER_OK_WITH_DEF(op, def, "", "[]");
+  TensorShape{1, 2, 3, 4}.AsProto(shape_proto);
+  rebuild_node_def();
+  INFER_OK_WITH_DEF(op, def, "", "[1,2,3,4]");
+
+  shape_proto->add_dim()->set_size(-1);
+  rebuild_node_def();
+  INFER_ERROR_WITH_DEF("Shape [1,2,3,4,-1] has negative dimensions", op, def,
+                       "");
+}
+
+}  // end namespace tensorflow