Require same shape for x
and y
in shape function of ApproximateEqual
(#19878)
* Require same shape for `x` and `y` in shape function of `ApproximateEqual` In the kernel implementation of `ApproximateEqual` the shape of inputs `x` and `y` should be the same. Though in the shape function of `ApproximateEqual` there was no such validation. This fix adds the shape validation in the shape function to make sure `x` and `y` are of the same shape, if they are known. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case for shape function of ApproximateEqual Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
5fa7b03a25
commit
400a398a18
@ -592,7 +592,13 @@ REGISTER_OP("ApproximateEqual")
|
|||||||
.SetIsCommutative()
|
.SetIsCommutative()
|
||||||
.Attr("T: numbertype")
|
.Attr("T: numbertype")
|
||||||
.Attr("tolerance: float = 0.00001")
|
.Attr("tolerance: float = 0.00001")
|
||||||
.SetShapeFn(shape_inference::UnchangedShape);
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
// The inputs 'x' and 'y' must have the same shape.
|
||||||
|
ShapeHandle data_x = c->input(0);
|
||||||
|
ShapeHandle data_y = c->input(1);
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
|
||||||
|
return shape_inference::UnchangedShape(c);
|
||||||
|
});
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@ -235,6 +235,15 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase):
|
|||||||
z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001))
|
z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001))
|
||||||
self.assertAllEqual(z, z_tf)
|
self.assertAllEqual(z, z_tf)
|
||||||
|
|
||||||
|
def testApproximateEqualShape(self):
|
||||||
|
for dtype in [np.float32, np.double]:
|
||||||
|
x = np.array([1, 2], dtype=np.float32)
|
||||||
|
y = np.array([[1, 2]], dtype=np.float32)
|
||||||
|
# The inputs 'x' and 'y' must have the same shape.
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, "Shapes must be equal rank, but are 1 and 2"):
|
||||||
|
math_ops.approximate_equal(x, y)
|
||||||
|
|
||||||
|
|
||||||
class ScalarMulTest(test_util.TensorFlowTestCase):
|
class ScalarMulTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user