Require same shape for x
and y
in shape function of ApproximateEqual
(#19878)
* Require same shape for `x` and `y` in shape function of `ApproximateEqual` In the kernel implementation of `ApproximateEqual` the shape of inputs `x` and `y` should be the same. Though in the shape function of `ApproximateEqual` there was no such validation. This fix adds the shape validation in the shape function to make sure `x` and `y` are of the same shape, if they are known. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case for shape function of ApproximateEqual Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
5fa7b03a25
commit
400a398a18
@ -592,7 +592,13 @@ REGISTER_OP("ApproximateEqual")
|
||||
.SetIsCommutative()
|
||||
.Attr("T: numbertype")
|
||||
.Attr("tolerance: float = 0.00001")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// The inputs 'x' and 'y' must have the same shape.
|
||||
ShapeHandle data_x = c->input(0);
|
||||
ShapeHandle data_y = c->input(1);
|
||||
TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
|
||||
return shape_inference::UnchangedShape(c);
|
||||
});
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
|
@ -235,6 +235,15 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase):
|
||||
z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001))
|
||||
self.assertAllEqual(z, z_tf)
|
||||
|
||||
def testApproximateEqualShape(self):
|
||||
for dtype in [np.float32, np.double]:
|
||||
x = np.array([1, 2], dtype=np.float32)
|
||||
y = np.array([[1, 2]], dtype=np.float32)
|
||||
# The inputs 'x' and 'y' must have the same shape.
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Shapes must be equal rank, but are 1 and 2"):
|
||||
math_ops.approximate_equal(x, y)
|
||||
|
||||
|
||||
class ScalarMulTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user