From a74f9c3c612586ba4581bd6324a7c1ced69ec5a3 Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Thu, 25 Apr 2019 12:30:56 -0700
Subject: [PATCH] Add support for composite tensors, such as SparseTensor and
 RaggedTensor, to while_v2

PiperOrigin-RevId: 245285953
---
 tensorflow/python/framework/ops.py            |   4 +-
 tensorflow/python/framework/sparse_tensor.py  |   4 +-
 .../kernel_tests/control_flow_ops_py_test.py  | 140 +++++++++++-------
 tensorflow/python/ops/control_flow_ops.py     |   2 +-
 tensorflow/python/ops/while_v2.py             |  43 ++++--
 5 files changed, 118 insertions(+), 75 deletions(-)

diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index d23cfc77c94..f806a15a94a 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1761,10 +1761,10 @@ class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
     if shape is None:
       shape = self._values.shape
     if self._dense_shape is None:
-      return [shape, shape[:1]]  # values, indices
+      return (shape, shape[:1])  # values, indices
     else:
       # values, indices, dense_shape
-      return [shape, shape[:1], tensor_shape.TensorShape([shape.ndims])]
+      return (shape, shape[:1], tensor_shape.TensorShape([shape.ndims]))

   @property
   def _is_graph_tensor(self):
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index e9199b1e661..cb427b4a4d2 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -250,11 +250,11 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
       raise ValueError("Shape invariant for SparseTensor must have the form "
                        "TensorShape([r]), got %r" % shape)
     rank = tensor_shape.dimension_value(shape[0])
-    return [
+    return (
         tensor_shape.TensorShape([None, rank]),  # indices
         tensor_shape.TensorShape([None]),  # values
         tensor_shape.TensorShape([rank])  # dense_shape
-    ]
+    )

   @property
   def _is_graph_tensor(self):
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 2f89ee4746c..43f57112bb0 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1790,6 +1790,18 @@ class ControlFlowTest(test.TestCase):
       r = r[1] * array_ops.ones([8, 8])
       self.assertAllEqual(np.ones((8, 8)), self.evaluate(r))

+  @test_util.disable_control_flow_v2("b/131265085")
+  @test_util.run_v1_only("b/131265085")
+  def testWhileBadShape(self):
+    x = constant_op.constant([2.0, 4.0], name="values")
+    i = constant_op.constant(0)
+    c = lambda i, _: math_ops.less(i, 10)
+    b = lambda i, x: [i + 1, x + 1]
+    with self.assertRaisesRegexp(ValueError, "is not compatible with"):
+      # Shape of x is [2], but we specify a shape of [5].
+      control_flow_ops.while_loop(
+          c, b, [i, x], [i.shape, tensor_shape.TensorShape([5])])
+
   @test_util.run_deprecated_v1
   def testWhileWithNonTensorInput_Scalar(self):
     with self.cached_session():
@@ -1807,7 +1819,6 @@ class ControlFlowTest(test.TestCase):
       r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
       self.assertEqual([10000], self.evaluate(r))

-  @test_util.run_v1_only("b/120545219")
   def testWhileShapeInference(self):
     with self.cached_session():
       i = constant_op.constant(0)
@@ -1822,19 +1833,23 @@ class ControlFlowTest(test.TestCase):
       r = control_flow_ops.while_loop(
           c, b, [i, m],
           [i.get_shape(), tensor_shape.TensorShape([None, 2])])
-      self.assertIsNone(r[1].shape.dims[0].value)
-      self.assertEqual(r[1].shape.dims[1], tensor_shape.Dimension(2))
+      self.assertTrue(r[1].shape.is_compatible_with([8, 2]))

+  @test_util.run_v1_only("b/120545219")
+  def testWhileShapeInferenceBadShape(self):
+    with self.cached_session():
+      i = constant_op.constant(0)
+      m = array_ops.ones([2, 2])
+      c = lambda i, j: math_ops.less(i, 2)
+      b = lambda i, j: [i + 1, array_ops.concat([j, j], 0)]
       with self.assertRaisesRegexp(
           ValueError,
           r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has "
           r"shape \(4, 2\) after one iteration. To allow the shape to vary "
           r"across iterations, use the `shape_invariants` argument of "
           r"tf.while_loop to specify a less-specific shape."):
-        r = control_flow_ops.while_loop(c, b, [i, m])
+        control_flow_ops.while_loop(c, b, [i, m])

-  @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
-  @test_util.run_v1_only("b/120545219")
   def testWhileShapeInferenceSparseTensor(self):
     values = constant_op.constant([2.0, 4.0], name="values")
     indices = constant_op.constant([[0], [3]],
                                    dtype=dtypes.int64,
                                    name="indices")
@@ -1873,61 +1888,72 @@ class ControlFlowTest(test.TestCase):
               array_ops.concat([x.dense_shape, [10]], axis=0))
       ]

+    def check_shapes(r, indices, values, dense_shape):
+      self.assertTrue(r.indices.shape.is_compatible_with(indices))
+      self.assertTrue(r.values.shape.is_compatible_with(values))
+      self.assertTrue(r.dense_shape.shape.is_compatible_with(dense_shape))
+
     # Default shape invariant; b1 only modifies values.
     _, r = control_flow_ops.while_loop(c, b1, [i, x])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
+    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])

     # Default shape invariant; b2 adds new values
     _, r = control_flow_ops.while_loop(c, b2, [i, x])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
-
-    # Default shape invariant; b3 modifies rank (which is not allowed).
-    with self.assertRaises(ValueError):
-      _, r = control_flow_ops.while_loop(c, b3, [i, x])
+    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])

     # Explicit shape invariant, allowing any rank; b1 only modifies values.
     _, r = control_flow_ops.while_loop(
         c, b1, [i, x],
         [i.get_shape(), tensor_shape.TensorShape([None])])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
+    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])

     # Explicit shape invariant, allowing any rank; b3 modifies rank.
     _, r = control_flow_ops.while_loop(
         c, b3, [i, x],
         [i.get_shape(), tensor_shape.TensorShape([None])])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
+    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])

     # Shape invariant with ndims=None. Technically, this isn't supported
     # according to the docs, but we support it for backwards compatibility.
     _, r = control_flow_ops.while_loop(
         c, b1, [i, x],
         [i.get_shape(), tensor_shape.TensorShape(None)])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
+    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])

     _, r = control_flow_ops.while_loop(
         c, b3, [i, x],
         [i.get_shape(), tensor_shape.TensorShape(None)])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
+    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
+
+  @test_util.disable_control_flow_v2("b/131265085")
+  @test_util.run_v1_only("b/131265085")
+  def testWhileBadShapeSparseTensor(self):
+    values = constant_op.constant([2.0, 4.0], name="values")
+    indices = constant_op.constant([[0], [3]],
+                                   dtype=dtypes.int64,
+                                   name="indices")
+    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
+    i = constant_op.constant(0)
+    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
+    c = lambda i, _: i < 10
+    b1 = lambda i, x: [i + 1, x]
+    def b2(i, x):  # Modifies rank: the shape of every component changes.
+      return [
+          i + 1,
+          sparse_tensor.SparseTensor(
+              array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
+              array_ops.concat([x.dense_shape, [10]], axis=0))
+      ]

     # Explicit shape invariant, with a specific (incompatible) rank.
     with self.assertRaisesRegexp(ValueError, "is not compatible with"):
-      _, r = control_flow_ops.while_loop(
+      control_flow_ops.while_loop(
           c, b1, [i, x],
           [i.get_shape(), tensor_shape.TensorShape([5])])

-  @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
-  @test_util.run_v1_only("b/120545219")
+    # Default shape invariant, but b2 modifies rank (which is not allowed).
+    with self.assertRaises(ValueError):
+      control_flow_ops.while_loop(c, b2, [i, x])
+
   def testWhileShapeInferenceIndexedSlices(self):
     with self.cached_session():
       values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
@@ -1953,17 +1979,28 @@ class ControlFlowTest(test.TestCase):
           c, b, [i, x],
           [i.get_shape(), tensor_shape.TensorShape([None, 2])])
       self.assertEqual(r.dense_shape.get_shape()[0], 2)
-      self.assertEqual(r.values.get_shape().as_list(), [None, 2])
+      self.assertTrue(r.values.get_shape().is_compatible_with([None, 2]))

-      with self.assertRaisesRegexp(ValueError, "is not compatible with"):
-        _, r = control_flow_ops.while_loop(
-            c, b, [i, x],
-            [i.get_shape(), tensor_shape.TensorShape([None, 5])])
+  @test_util.disable_control_flow_v2("b/131265085")
+  @test_util.run_v1_only("b/131265085")
+  def testWhileBadShapeIndexedSlices(self):
+    values = constant_op.constant([2.0, 4.0], name="values")
+    indices = constant_op.constant([0, 3],
+                                   dtype=dtypes.int64,
+                                   name="indices")
+    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
+    i = constant_op.constant(0)
+    x = ops.IndexedSlices(values, indices, dense_shape=shape)
+    c = lambda i, _: i < 10
+    b = lambda i, x: [i + 1, x]
+
+    # Explicit shape invariant, with a specific (incompatible) rank.
+    with self.assertRaisesRegexp(ValueError, "is not compatible with"):
+      control_flow_ops.while_loop(
+          c, b, [i, x],
+          [i.get_shape(), tensor_shape.TensorShape([5])])

-  @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
   def testWhileShapeInferenceRaggedTensor(self):
-    if context.executing_eagerly():
-      self.skipTest("b/116328420")
     i = constant_op.constant(0)
     x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
     c = lambda i, _: i < 10
@@ -1980,11 +2017,13 @@ class ControlFlowTest(test.TestCase):
           array_ops.concat([x, x], axis=0)
       ]

+    def check_shapes(r, values, splits):
+      self.assertTrue(r.values.shape.is_compatible_with(values))
+      self.assertTrue(r.row_splits.shape.is_compatible_with(splits))
+
     # Default shape invariant; b1 adds new values to rows.
     _, r = control_flow_ops.while_loop(c, b1, [i, x])
-    self.assertEqual(r.row_splits.shape.as_list(), [4])
-
-    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
+    check_shapes(r, values=[None], splits=[4])

     # Default shape invariant; b2 adds new rows (not allowed).
     if not context.executing_eagerly():
@@ -1995,20 +2034,15 @@ class ControlFlowTest(test.TestCase):
     _, r = control_flow_ops.while_loop(
         c, b1, [i, x],
         [i.get_shape(), tensor_shape.TensorShape([None, None])])
-    self.assertTrue(r.row_splits.shape.as_list() in ([4], [None]))
-    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
+    check_shapes(r, values=[None], splits=[None])

     # Explicit shape invariant; b2 adds new rows.
     _, r = control_flow_ops.while_loop(
         c, b2, [i, x],
         [i.get_shape(), tensor_shape.TensorShape([None, None])])
-    self.assertTrue(r.row_splits.shape.as_list() in ([3 * 2**10 + 1], [None]))
-    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
+    check_shapes(r, values=[None], splits=[None])

-  @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
   def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
-    if context.executing_eagerly():
-      self.skipTest("b/116328420")
     i = constant_op.constant(0)
     x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
                                      [[], [8, 9, 10]]])
@@ -3473,8 +3507,7 @@ class ControlFlowTest(test.TestCase):
       self.assertEqual(0, value_x)
      self.assertEqual(73, value_x_grad)

-  @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
-  @test_util.run_v1_only("b/120545219")
+  @test_util.deprecated_graph_mode_only
   def testWhileGrad_IndexedSlices(self):
     with self.cached_session():
       values = constant_op.constant([2.0, 4.0], name="values")
@@ -3496,8 +3529,7 @@ class ControlFlowTest(test.TestCase):
       r = gradients_impl.gradients(r.values, values)[0]
       self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))

-  @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
-  @test_util.run_v1_only("b/120545219")
+  @test_util.deprecated_graph_mode_only
   def testWhileGrad_SparseTensor(self):
     with self.cached_session():
       values = constant_op.constant([2.0, 4.0], name="values")
@@ -3520,7 +3552,7 @@ class ControlFlowTest(test.TestCase):
       r = gradients_impl.gradients(r.values, values)[0]
       self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))

-  @test_util.run_v1_only("b/120545219")
+  @test_util.deprecated_graph_mode_only
   def testCallGradInLoop(self):
     with self.cached_session() as sess:
       i0 = constant_op.constant(0)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index c1181a74dc6..daa99be060f 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -3466,7 +3466,7 @@ def while_loop(cond,
         return x
       return ops.convert_to_tensor(x)

-    loop_vars = nest.map_structure(convert, loop_vars)
+    loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
     if maximum_iterations is not None:
       return loop_vars[1]
     else:
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index cbbc0de9370..50e7d244d5e 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -72,12 +72,18 @@ def while_loop(cond,
   # `wrapped_body` below.
   loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
   loop_vars = nest.map_structure(
-      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
+      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
+      expand_composites=True)
   if shape_invariants is not None:
-    nest.assert_same_structure(orig_loop_vars, shape_invariants)
+    nest.assert_same_structure(orig_loop_vars, shape_invariants,
+                               expand_composites=False)
+    shape_invariants = nest.map_structure(
+        control_flow_ops._get_shape_invariant, loop_vars,
+        list(shape_invariants), expand_composites=False)
   else:
-    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)
-
+    shape_invariants = nest.map_structure(
+        control_flow_ops._get_shape_invariant, loop_vars,
+        expand_composites=False)
   if not name:
     name = "while"

@@ -150,11 +156,12 @@ def while_loop(cond,
       # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
       # and packs it into the structure of `orig_loop_vars`.
       outputs = body(*_pack_sequence_as(orig_loop_vars, args))
-      if not nest.is_sequence(outputs):
+      if not nest.is_sequence_or_composite(outputs):
         outputs = [outputs]
       # Compare the structure of input and output of body converting the
       # top-level tuples to list to be compatible with legacy while_loop.
-      nest.assert_same_structure(list(outputs), list(orig_loop_vars))
+      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
+                                 expand_composites=True)

       outputs = _tensor_array_to_flow(outputs)

@@ -193,7 +200,8 @@ def while_loop(cond,
   # Make sure that the shapes of the loop outputs are compatible with the
   # shape invariants, or the shapes of the loop vars if the invariants are not
   # specified.
-  num_flattened_outputs = len(nest.flatten(orig_loop_vars))
+  num_flattened_outputs = len(nest.flatten(orig_loop_vars,
+                                           expand_composites=True))
   # First var is loop counter and second var is maximum_iterations.
   first_loop_var_index = 2
   _check_shapes_compat(
       body_graph.outputs[first_loop_var_index:first_loop_var_index +
                          num_flattened_outputs],
       nest.flatten(
           shape_invariants[first_loop_var_index:first_loop_var_index +
-                           len_orig_loop_vars]),
+                           len_orig_loop_vars], expand_composites=True),
       nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
-                             len_orig_loop_vars]))
-  flattened_loop_vars = nest.flatten(loop_vars)
+                             len_orig_loop_vars], expand_composites=True))
+  flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
   _check_num_inputs_outputs(cond_graph, body_graph,
                             len(flattened_loop_vars))

@@ -237,7 +245,7 @@ def while_loop(cond,
   if return_same_structure:
     return outputs

-  flattened_outputs = nest.flatten(outputs)
+  flattened_outputs = nest.flatten(outputs, expand_composites=True)
   if len(flattened_outputs) == 1:
     return flattened_outputs[0]
   else:
@@ -905,9 +913,11 @@ def _pack_sequence_as(structure_with_tas, loop_vars):

   flattened_loop_vars = [
       flow_to_tensor_array(*z)
-      for z in zip(nest.flatten(loop_vars), nest.flatten(structure_with_tas))
+      for z in zip(nest.flatten(loop_vars, expand_composites=True),
+                   nest.flatten(structure_with_tas, expand_composites=True))
   ]
-  return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars)
+  return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars,
+                               expand_composites=True)


 def _tensor_array_to_flow(loop_vars):
@@ -917,14 +927,15 @@ def _tensor_array_to_flow(loop_vars):
       return maybe_ta.flow
     return maybe_ta

-  return nest.map_structure(f, loop_vars)
+  return nest.map_structure(f, loop_vars, expand_composites=True)


 def _build_signature(loop_vars, shape_invariants):
   return nest.pack_sequence_as(loop_vars, [
       tensor_spec.TensorSpec(s, t.dtype, name=t.op.name)
-      for s, t in zip(nest.flatten(shape_invariants), nest.flatten(loop_vars))
-  ])
+      for s, t in zip(nest.flatten(shape_invariants, expand_composites=True),
+                      nest.flatten(loop_vars, expand_composites=True))
+  ], expand_composites=True)


 def _build_maximum_iterations_loop_var(maximum_iterations):
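
Usage note (editor's addition, not part of the patch): the sketch below illustrates what this change enables. With v2 control flow, a SparseTensor can now be carried through tf.while_loop as a loop variable, either with the default shape invariant derived from the input or with an explicit TensorShape([rank]) invariant, as exercised by the tests above. The TF 2.x-style API calls are an assumption for illustration, not code from this CL.

    # Illustrative sketch only; assumes a TF build that includes this change
    # (v2 control flow, the default in TF 2.x).
    import tensorflow as tf

    i = tf.constant(0)
    x = tf.sparse.SparseTensor(indices=[[0], [3]], values=[2.0, 4.0],
                               dense_shape=[10])

    def body(i, x):
      # Only the values change; the rank and component shapes are preserved,
      # so the default shape invariant (derived from x itself) is sufficient.
      return [i + 1,
              tf.sparse.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)]

    _, r = tf.while_loop(lambda i, _: i < 10, body, [i, x])

    # The explicit invariant form for a SparseTensor is TensorShape([rank]);
    # TensorShape([None]) leaves the rank unconstrained.
    _, r = tf.while_loop(lambda i, _: i < 10, body, [i, x],
                         shape_invariants=[i.shape, tf.TensorShape([1])])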