From c697d96ec69333fcb5b2abe95a10cd53b68757fe Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 8 Dec 2017 17:10:40 -0800 Subject: [PATCH] Preserve symbolic shape information as much as possible during shape creation PiperOrigin-RevId: 178448208 --- tensorflow/core/framework/shape_inference.cc | 18 ++++++------------ .../core/framework/shape_inference_test.cc | 8 ++++---- tensorflow/core/ops/array_ops_test.cc | 6 +++--- tensorflow/core/ops/math_ops_test.cc | 2 +- tensorflow/core/ops/spectral_ops_test.cc | 6 +++--- 5 files changed, 17 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index ee9192d4a14..c13f13a126f 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -342,8 +342,8 @@ Status InferenceContext::WithRank(ShapeHandle shape, int64 rank, for (int i = 0; i < rank; ++i) { dims.push_back(UnknownDim()); } - *out = shape_manager_.MakeShape(dims); - return Status::OK(); + ShapeHandle shp = shape_manager_.MakeShape(dims); + return Merge(shape, shp, out); } *out = nullptr; @@ -357,13 +357,10 @@ Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank, return errors::InvalidArgument("Rank cannot exceed kint32max"); } const int32 existing = Rank(shape); - if (existing >= rank) { + if (existing >= rank || existing == kUnknownRank) { *out = shape; return Status::OK(); } - if (existing == kUnknownRank) { - return ReturnUnknownShape(out); - } *out = nullptr; return errors::InvalidArgument("Shape must be at least rank ", rank, " but is rank ", existing); @@ -375,10 +372,7 @@ Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank, return errors::InvalidArgument("Rank cannot exceed kint32max"); } const int32 existing = Rank(shape); - if (existing == kUnknownRank) { - return ReturnUnknownShape(out); - } - if (existing <= rank) { + if (existing <= rank || existing == kUnknownRank) { *out = shape; return Status::OK(); } @@ -395,8 +389,8 @@ Status InferenceContext::WithValue(DimensionHandle dim, int64 value, return Status::OK(); } if (existing == kUnknownDim) { - *out = MakeDim(value); - return Status::OK(); + DimensionHandle d = MakeDim(value); + return Merge(dim, d, out); } *out = nullptr; return errors::InvalidArgument("Dimension must be ", value, " but is ", diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 68156e63ca7..a9b63ca60e4 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -359,11 +359,11 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) { // WithRankAtMost on a shape with unknown dimensionality always succeeds. EXPECT_TRUE(c.WithRankAtMost(in0, 1, &s1).ok()); EXPECT_EQ("?", c.DebugString(s1)); - EXPECT_FALSE(SameHandle(in0, s1)); + EXPECT_TRUE(SameHandle(in0, s1)); EXPECT_TRUE(c.WithRankAtMost(in0, 2, &s2).ok()); EXPECT_EQ("?", c.DebugString(s2)); - EXPECT_FALSE(SameHandle(s1, s2)); + EXPECT_TRUE(SameHandle(s1, s2)); // WithRankAtMost on shape with known dimensionality. s1 = in1; @@ -398,11 +398,11 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) { // WithRankAtLeast on a shape with unknown dimensionality always succeeds. EXPECT_TRUE(c.WithRankAtLeast(in0, 1, &s1).ok()); EXPECT_EQ("?", c.DebugString(s1)); - EXPECT_FALSE(SameHandle(in0, s1)); + EXPECT_TRUE(SameHandle(in0, s1)); EXPECT_TRUE(c.WithRankAtLeast(in0, 2, &s2).ok()); EXPECT_EQ("?", c.DebugString(s2)); - EXPECT_FALSE(SameHandle(s1, s2)); + EXPECT_TRUE(SameHandle(s1, s2)); // WithRankAtLeast on shape with known dimensionality. s1 = in1; diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index e010ecda8e7..c8ea4436136 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -521,7 +521,7 @@ TEST(ArrayOpsTest, MatrixSetDiag_ShapeFn) { INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[2,3];[3]"); // Output matches input. - INFER_OK(op, "?;?", "?"); + INFER_OK(op, "?;?", "in0"); INFER_OK(op, "[1,2,2];[1,2]", "in0"); INFER_OK(op, "[1,2,3];?", "in0"); INFER_OK(op, "[1,3,2];?", "in0"); @@ -1619,7 +1619,7 @@ TEST(ArrayOpsTest, UnchangedWithQuantizationScalars_ShapeFn) { TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) { ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannel"); - INFER_OK(op, "?;?;?", "?"); + INFER_OK(op, "?;?;?", "in0"); INFER_OK(op, "[?];?;?", "in0"); INFER_OK(op, "[1,?,3];[3];[3]", "in0"); INFER_OK(op, "[3];[3];[3]", "in0"); @@ -1638,7 +1638,7 @@ TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) { TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) { ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient"); - INFER_OK(op, "?;?;?;?", "?;[?];[?]"); + INFER_OK(op, "?;?;?;?", "in0;[?];[?]"); INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3"); INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3"); INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3"); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 3dfa776d26f..ca3772e6f89 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -522,7 +522,7 @@ TEST(MathOpsTest, Cross_ShapeFn) { INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]"); INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]"); - INFER_OK(op, "?;?", "?"); + INFER_OK(op, "?;?", "in0"); INFER_OK(op, "[?];[?]", "in0"); INFER_OK(op, "[1,?,3];[?,?,?]", "in0"); } diff --git a/tensorflow/core/ops/spectral_ops_test.cc b/tensorflow/core/ops/spectral_ops_test.cc index 0f8a3e6ef13..b1c5e95fc5c 100644 --- a/tensorflow/core/ops/spectral_ops_test.cc +++ b/tensorflow/core/ops/spectral_ops_test.cc @@ -22,7 +22,7 @@ namespace tensorflow { TEST(MathOpsTest, FFT_ShapeFn) { for (const auto* op_name : {"FFT", "IFFT"}) { ShapeInferenceTestOp op(op_name); - INFER_OK(op, "?", "?"); + INFER_OK(op, "?", "in0"); INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]"); INFER_OK(op, "[?]", "in0"); INFER_OK(op, "[1]", "in0"); @@ -31,7 +31,7 @@ TEST(MathOpsTest, FFT_ShapeFn) { for (const auto* op_name : {"FFT2D", "IFFT2D"}) { ShapeInferenceTestOp op(op_name); - INFER_OK(op, "?", "?"); + INFER_OK(op, "?", "in0"); INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]"); INFER_OK(op, "[?,1]", "in0"); INFER_OK(op, "[1,2]", "in0"); @@ -40,7 +40,7 @@ TEST(MathOpsTest, FFT_ShapeFn) { for (const auto* op_name : {"FFT3D", "IFFT3D"}) { ShapeInferenceTestOp op(op_name); - INFER_OK(op, "?", "?"); + INFER_OK(op, "?", "in0"); INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, "[1,2]"); INFER_OK(op, "[?,1,?]", "in0"); INFER_OK(op, "[1,2,3]", "in0");