Require same shape for x and y in shape function of ApproximateEqual (#19878)

* Require same shape for `x` and `y` in shape function of `ApproximateEqual`

In the kernel implementation of `ApproximateEqual` the shape of inputs
`x` and `y` should be the same. Though in the shape function of `ApproximateEqual`
there was no such validation. This fix adds the shape validation in the
shape function to make sure `x` and `y` are of the same shape, if they are known.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test case for shape function of ApproximateEqual

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2018-06-12 14:40:54 -07:00 committed by drpngx
parent 5fa7b03a25
commit 400a398a18
2 changed files with 16 additions and 1 deletions

View File

@ -592,7 +592,13 @@ REGISTER_OP("ApproximateEqual")
.SetIsCommutative()
.Attr("T: numbertype")
.Attr("tolerance: float = 0.00001")
.SetShapeFn(shape_inference::UnchangedShape);
.SetShapeFn([](InferenceContext* c) {
// The inputs 'x' and 'y' must have the same shape.
ShapeHandle data_x = c->input(0);
ShapeHandle data_y = c->input(1);
TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
return shape_inference::UnchangedShape(c);
});
// --------------------------------------------------------------------------

View File

@ -235,6 +235,15 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase):
z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001))
self.assertAllEqual(z, z_tf)
def testApproximateEqualShape(self):
for dtype in [np.float32, np.double]:
x = np.array([1, 2], dtype=np.float32)
y = np.array([[1, 2]], dtype=np.float32)
# The inputs 'x' and 'y' must have the same shape.
with self.assertRaisesRegexp(
ValueError, "Shapes must be equal rank, but are 1 and 2"):
math_ops.approximate_equal(x, y)
class ScalarMulTest(test_util.TensorFlowTestCase):