Add shape test for ragged reduce_sum.
PiperOrigin-RevId: 266944199
This commit is contained in:
parent
fa9b9dc241
commit
0c3bae3842
@ -321,6 +321,13 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase,
|
|||||||
reduced = ragged_reduce_op(rt_input, axis)
|
reduced = ragged_reduce_op(rt_input, axis)
|
||||||
self.assertAllEqual(reduced, expected)
|
self.assertAllEqual(reduced, expected)
|
||||||
|
|
||||||
|
def testReduceKeepsInnerDimensionShape(self):
|
||||||
|
# Test for bug [b/139823356].
|
||||||
|
rt = ragged_factory_ops.constant([[[[1, 1]]]], ragged_rank=2)
|
||||||
|
self.assertEqual(rt.shape.as_list(), [1, None, None, 2])
|
||||||
|
reduced = ragged_math_ops.reduce_sum(rt, axis=2)
|
||||||
|
self.assertEqual(reduced.shape.as_list(), [1, None, 2])
|
||||||
|
|
||||||
def assertEqualWithNan(self, actual, expected):
|
def assertEqualWithNan(self, actual, expected):
|
||||||
"""Like assertEqual, but NaN==NaN."""
|
"""Like assertEqual, but NaN==NaN."""
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
Loading…
Reference in New Issue
Block a user