Add shape test for ragged reduce_sum.
PiperOrigin-RevId: 266944199
This commit is contained in:
parent
fa9b9dc241
commit
0c3bae3842
@ -321,6 +321,13 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase,
|
||||
reduced = ragged_reduce_op(rt_input, axis)
|
||||
self.assertAllEqual(reduced, expected)
|
||||
|
||||
def testReduceKeepsInnerDimensionShape(self):
|
||||
# Test for bug [b/139823356].
|
||||
rt = ragged_factory_ops.constant([[[[1, 1]]]], ragged_rank=2)
|
||||
self.assertEqual(rt.shape.as_list(), [1, None, None, 2])
|
||||
reduced = ragged_math_ops.reduce_sum(rt, axis=2)
|
||||
self.assertEqual(reduced.shape.as_list(), [1, None, 2])
|
||||
|
||||
def assertEqualWithNan(self, actual, expected):
|
||||
"""Like assertEqual, but NaN==NaN."""
|
||||
self.assertTrue(
|
||||
|
Loading…
Reference in New Issue
Block a user