diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 929213656cc..6d1ef56608c 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -592,7 +592,13 @@ REGISTER_OP("ApproximateEqual") .SetIsCommutative() .Attr("T: numbertype") .Attr("tolerance: float = 0.00001") - .SetShapeFn(shape_inference::UnchangedShape); + .SetShapeFn([](InferenceContext* c) { + // The inputs 'x' and 'y' must have the same shape. + ShapeHandle data_x = c->input(0); + ShapeHandle data_y = c->input(1); + TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x)); + return shape_inference::UnchangedShape(c); + }); // -------------------------------------------------------------------------- diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 980c92b0d59..c807c8bc2ef 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -235,6 +235,15 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase): z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001)) self.assertAllEqual(z, z_tf) + def testApproximateEqualShape(self): + for dtype in [np.float32, np.double]: + x = np.array([1, 2], dtype=np.float32) + y = np.array([[1, 2]], dtype=np.float32) + # The inputs 'x' and 'y' must have the same shape. + with self.assertRaisesRegexp( + ValueError, "Shapes must be equal rank, but are 1 and 2"): + math_ops.approximate_equal(x, y) + class ScalarMulTest(test_util.TensorFlowTestCase):