From 83f19c6a9e84fc6971ad0a7df5874603237a595f Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 15 Jun 2020 17:16:56 +0000 Subject: [PATCH 1/2] Fix unknown output shape issue in autograph for tf.equal This PR tries to address the issue raised in 40471 where the output shape of an autograph consists of tf.equal could not inference correctly. Specifically `x.shape == [None, 10, 1]` and `y.shape == [None, 1, 4]` only yield `shape == None` (should be `shape == [None, 10, 4]`). The reason was that the shape inbference function for equal didn't capture the cases where both x and y's dim are None. This PR fixes the issue. This PR fixes 40471. Signed-off-by: Yong Tang --- tensorflow/core/framework/common_shape_fns.cc | 3 +++ tensorflow/core/ops/math_ops_test.cc | 2 +- .../python/autograph/operators/logical_test.py | 14 ++++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 113adbdd432..7567db03c23 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1936,6 +1936,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, // in C++ op code, we must still assert that the unknown dim is either 1 // or the same as the known dim. // - If either dimension is 1, the other dimension is the output. + // - If both are unknown then dimension is unknown if (c->Value(dim_x) > 1) { if (!incompatible_shape_error) { *out = c->UnknownShape(); @@ -1954,6 +1955,8 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, dims.push_back(dim_x); } else if (dim_y.SameHandle(dim_x)) { dims.push_back(dim_x); + } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) { + dims.push_back(c->UnknownDim()); } else { if (!incompatible_shape_error) { *out = c->UnknownShape(); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 5c69a2a7f1c..a2837d88bde 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -120,7 +120,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) { INFER_OK(op, "[1];[?]", "[d1_0]"); INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?"); INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?"); - INFER_OK(op, "[?];[?]", incompatible_shape_error ? "[?]" : "?"); + INFER_OK(op, "[?];[?]", "[?]"); INFER_OK(op, "[];[?]", "[d1_0]"); INFER_OK(op, "[?];[]", "[d0_0]"); diff --git a/tensorflow/python/autograph/operators/logical_test.py b/tensorflow/python/autograph/operators/logical_test.py index e22f39932d1..0eab302a825 100644 --- a/tensorflow/python/autograph/operators/logical_test.py +++ b/tensorflow/python/autograph/operators/logical_test.py @@ -19,7 +19,9 @@ from __future__ import division from __future__ import print_function from tensorflow.python.autograph.operators import logical +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.platform import test @@ -83,6 +85,18 @@ class LogicalOperatorsTest(test.TestCase): t = logical.not_(self._tf_false()) self.assertEqual(self.evaluate(t), True) + # Test case for GitHub issue 40471 + def test_equal_output_shapes(self): + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec([None, 10, 1]), + tensor_spec.TensorSpec([None, 1, 4])]) + def f(x, y): + z = x == y + return z + + self.assertAllEqual(f.get_concrete_function().output_shapes, [None, 10, 4]) + if __name__ == '__main__': test.main() From ef55a40b374d7310e4ce3149d86395d403403d0d Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 17 Jun 2020 18:38:24 +0000 Subject: [PATCH 2/2] Move tf.equal shape inference test to math_ops_test.cc also added additional shape inference test cases Signed-off-by: Yong Tang --- tensorflow/core/ops/math_ops_test.cc | 18 ++++++++++++++++++ .../python/autograph/operators/logical_test.py | 14 -------------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index a2837d88bde..2b65f88042c 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -604,4 +604,22 @@ TEST(MathOpsTest, SobolSample) { INFER_OK(op, "[];[];[]", "[?,?]"); } + +TEST(MathOpsTest, EqualOp) { + ShapeInferenceTestOp op("Equal"); + AddNodeAttr("incompatible_shape_error", true, &op.node_def); + + INFER_OK(op, "?;?", "?"); + INFER_OK(op, "[1,2];?", "?"); + INFER_OK(op, "?;[1,2]", "?"); + + INFER_OK(op, "[1,2,3];[1]", "[d0_0,d0_1,d0_2]"); + INFER_OK(op, "[?,2,1];[1,3]", "[d0_0,d0_1,d1_1]"); + INFER_OK(op, "[1,?,3];[3,1]", "[d0_0,d1_0,d0_2]"); + INFER_OK(op, "[1,2,3];[2,1,3]", "[d1_0,d0_1,d0_2]"); + + // Note: Test case for GitHub issue 40471 + INFER_OK(op, "[?,10,1];[?,1,4]", "[?,d0_1,d1_2]"); + INFER_OK(op, "[10,?,1];[1,?,4]", "[d0_0,?,d1_2]"); +} } // end namespace tensorflow diff --git a/tensorflow/python/autograph/operators/logical_test.py b/tensorflow/python/autograph/operators/logical_test.py index 0eab302a825..e22f39932d1 100644 --- a/tensorflow/python/autograph/operators/logical_test.py +++ b/tensorflow/python/autograph/operators/logical_test.py @@ -19,9 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.autograph.operators import logical -from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op -from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.platform import test @@ -85,18 +83,6 @@ class LogicalOperatorsTest(test.TestCase): t = logical.not_(self._tf_false()) self.assertEqual(self.evaluate(t), True) - # Test case for GitHub issue 40471 - def test_equal_output_shapes(self): - - @def_function.function(input_signature=[ - tensor_spec.TensorSpec([None, 10, 1]), - tensor_spec.TensorSpec([None, 1, 4])]) - def f(x, y): - z = x == y - return z - - self.assertAllEqual(f.get_concrete_function().output_shapes, [None, 10, 4]) - if __name__ == '__main__': test.main()