Move tf.equal shape inference test to math_ops_test.cc
also added additional shape inference test cases Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
83f19c6a9e
commit
ef55a40b37
@ -604,4 +604,22 @@ TEST(MathOpsTest, SobolSample) {
|
||||
|
||||
INFER_OK(op, "[];[];[]", "[?,?]");
|
||||
}
|
||||
|
||||
TEST(MathOpsTest, EqualOp) {
|
||||
ShapeInferenceTestOp op("Equal");
|
||||
AddNodeAttr("incompatible_shape_error", true, &op.node_def);
|
||||
|
||||
INFER_OK(op, "?;?", "?");
|
||||
INFER_OK(op, "[1,2];?", "?");
|
||||
INFER_OK(op, "?;[1,2]", "?");
|
||||
|
||||
INFER_OK(op, "[1,2,3];[1]", "[d0_0,d0_1,d0_2]");
|
||||
INFER_OK(op, "[?,2,1];[1,3]", "[d0_0,d0_1,d1_1]");
|
||||
INFER_OK(op, "[1,?,3];[3,1]", "[d0_0,d1_0,d0_2]");
|
||||
INFER_OK(op, "[1,2,3];[2,1,3]", "[d1_0,d0_1,d0_2]");
|
||||
|
||||
// Note: Test case for GitHub issue 40471
|
||||
INFER_OK(op, "[?,10,1];[?,1,4]", "[?,d0_1,d1_2]");
|
||||
INFER_OK(op, "[10,?,1];[1,?,4]", "[d0_0,?,d1_2]");
|
||||
}
|
||||
} // end namespace tensorflow
|
||||
|
@ -19,9 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.operators import logical
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -85,18 +83,6 @@ class LogicalOperatorsTest(test.TestCase):
|
||||
t = logical.not_(self._tf_false())
|
||||
self.assertEqual(self.evaluate(t), True)
|
||||
|
||||
# Test case for GitHub issue 40471
|
||||
def test_equal_output_shapes(self):
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec([None, 10, 1]),
|
||||
tensor_spec.TensorSpec([None, 1, 4])])
|
||||
def f(x, y):
|
||||
z = x == y
|
||||
return z
|
||||
|
||||
self.assertAllEqual(f.get_concrete_function().output_shapes, [None, 10, 4])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user