Add support to for composite tensors, such as SparseTensor and RaggedTensor, to while_v2
PiperOrigin-RevId: 245285953
This commit is contained in:
parent
421802c1b4
commit
a74f9c3c61
@ -1761,10 +1761,10 @@ class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
|
|||||||
if shape is None:
|
if shape is None:
|
||||||
shape = self._values.shape
|
shape = self._values.shape
|
||||||
if self._dense_shape is None:
|
if self._dense_shape is None:
|
||||||
return [shape, shape[:1]] # values, indices
|
return (shape, shape[:1]) # values, indices
|
||||||
else:
|
else:
|
||||||
# values, indices, dense_shape
|
# values, indices, dense_shape
|
||||||
return [shape, shape[:1], tensor_shape.TensorShape([shape.ndims])]
|
return (shape, shape[:1], tensor_shape.TensorShape([shape.ndims]))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _is_graph_tensor(self):
|
def _is_graph_tensor(self):
|
||||||
|
@ -250,11 +250,11 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
|
|||||||
raise ValueError("Shape invariant for SparseTensor must have the form "
|
raise ValueError("Shape invariant for SparseTensor must have the form "
|
||||||
"TensorShape([r]), got %r" % shape)
|
"TensorShape([r]), got %r" % shape)
|
||||||
rank = tensor_shape.dimension_value(shape[0])
|
rank = tensor_shape.dimension_value(shape[0])
|
||||||
return [
|
return (
|
||||||
tensor_shape.TensorShape([None, rank]), # indices
|
tensor_shape.TensorShape([None, rank]), # indices
|
||||||
tensor_shape.TensorShape([None]), # values
|
tensor_shape.TensorShape([None]), # values
|
||||||
tensor_shape.TensorShape([rank]) # dense_shape
|
tensor_shape.TensorShape([rank]) # dense_shape
|
||||||
]
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _is_graph_tensor(self):
|
def _is_graph_tensor(self):
|
||||||
|
@ -1790,6 +1790,18 @@ class ControlFlowTest(test.TestCase):
|
|||||||
r = r[1] * array_ops.ones([8, 8])
|
r = r[1] * array_ops.ones([8, 8])
|
||||||
self.assertAllEqual(np.ones((8, 8)), self.evaluate(r))
|
self.assertAllEqual(np.ones((8, 8)), self.evaluate(r))
|
||||||
|
|
||||||
|
@test_util.disable_control_flow_v2("b/131265085")
|
||||||
|
@test_util.run_v1_only("b/131265085")
|
||||||
|
def testWhileBadShape(self):
|
||||||
|
x = constant_op.constant([2.0, 4.0], name="values")
|
||||||
|
i = constant_op.constant(0)
|
||||||
|
c = lambda i, _: math_ops.less(i, 10)
|
||||||
|
b = lambda i, x: [i + 1, x + 1]
|
||||||
|
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
|
||||||
|
# Shape of x is [2], but we specify a shape of [5].
|
||||||
|
control_flow_ops.while_loop(
|
||||||
|
c, b, [i, x], [i.shape, tensor_shape.TensorShape([5])])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testWhileWithNonTensorInput_Scalar(self):
|
def testWhileWithNonTensorInput_Scalar(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
@ -1807,7 +1819,6 @@ class ControlFlowTest(test.TestCase):
|
|||||||
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
|
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
|
||||||
self.assertEqual([10000], self.evaluate(r))
|
self.assertEqual([10000], self.evaluate(r))
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testWhileShapeInference(self):
|
def testWhileShapeInference(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
i = constant_op.constant(0)
|
i = constant_op.constant(0)
|
||||||
@ -1822,19 +1833,23 @@ class ControlFlowTest(test.TestCase):
|
|||||||
r = control_flow_ops.while_loop(
|
r = control_flow_ops.while_loop(
|
||||||
c, b, [i, m],
|
c, b, [i, m],
|
||||||
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
|
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
|
||||||
self.assertIsNone(r[1].shape.dims[0].value)
|
self.assertTrue(r[1].shape.is_compatible_with([8, 2]))
|
||||||
self.assertEqual(r[1].shape.dims[1], tensor_shape.Dimension(2))
|
|
||||||
|
|
||||||
|
@test_util.run_v1_only("b/120545219")
|
||||||
|
def testWhileShapeInferenceBadShape(self):
|
||||||
|
with self.cached_session():
|
||||||
|
i = constant_op.constant(0)
|
||||||
|
m = array_ops.ones([2, 2])
|
||||||
|
c = lambda i, j: math_ops.less(i, 2)
|
||||||
|
b = lambda i, j: [i + 1, array_ops.concat([j, j], 0)]
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError,
|
||||||
r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has "
|
r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has "
|
||||||
r"shape \(4, 2\) after one iteration. To allow the shape to vary "
|
r"shape \(4, 2\) after one iteration. To allow the shape to vary "
|
||||||
r"across iterations, use the `shape_invariants` argument of "
|
r"across iterations, use the `shape_invariants` argument of "
|
||||||
r"tf.while_loop to specify a less-specific shape."):
|
r"tf.while_loop to specify a less-specific shape."):
|
||||||
r = control_flow_ops.while_loop(c, b, [i, m])
|
control_flow_ops.while_loop(c, b, [i, m])
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testWhileShapeInferenceSparseTensor(self):
|
def testWhileShapeInferenceSparseTensor(self):
|
||||||
values = constant_op.constant([2.0, 4.0], name="values")
|
values = constant_op.constant([2.0, 4.0], name="values")
|
||||||
indices = constant_op.constant([[0], [3]],
|
indices = constant_op.constant([[0], [3]],
|
||||||
@ -1873,61 +1888,72 @@ class ControlFlowTest(test.TestCase):
|
|||||||
array_ops.concat([x.dense_shape, [10]], axis=0))
|
array_ops.concat([x.dense_shape, [10]], axis=0))
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def check_shapes(r, indices, values, dense_shape):
|
||||||
|
self.assertTrue(r.indices.shape.is_compatible_with(indices))
|
||||||
|
self.assertTrue(r.values.shape.is_compatible_with(values))
|
||||||
|
self.assertTrue(r.dense_shape.shape.is_compatible_with(dense_shape))
|
||||||
|
|
||||||
# Default shape invariant; b1 only modifies values.
|
# Default shape invariant; b1 only modifies values.
|
||||||
_, r = control_flow_ops.while_loop(c, b1, [i, x])
|
_, r = control_flow_ops.while_loop(c, b1, [i, x])
|
||||||
self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
|
check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])
|
||||||
self.assertEqual(r.values.get_shape().as_list(), [None])
|
|
||||||
self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
|
|
||||||
|
|
||||||
# Default shape invariant; b2 adds new values
|
# Default shape invariant; b2 adds new values
|
||||||
_, r = control_flow_ops.while_loop(c, b2, [i, x])
|
_, r = control_flow_ops.while_loop(c, b2, [i, x])
|
||||||
self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
|
check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])
|
||||||
self.assertEqual(r.values.get_shape().as_list(), [None])
|
|
||||||
self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
|
|
||||||
|
|
||||||
# Default shape invariant; b3 modifies rank (which is not allowed).
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_, r = control_flow_ops.while_loop(c, b3, [i, x])
|
|
||||||
|
|
||||||
# Explicit shape invariant, allowing any rank; b1 only modifies values.
|
# Explicit shape invariant, allowing any rank; b1 only modifies values.
|
||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
c, b1, [i, x],
|
c, b1, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape([None])])
|
[i.get_shape(), tensor_shape.TensorShape([None])])
|
||||||
self.assertEqual(r.indices.get_shape().as_list(), [None, None])
|
check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
|
||||||
self.assertEqual(r.values.get_shape().as_list(), [None])
|
|
||||||
self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
|
|
||||||
|
|
||||||
# Explicit shape invariant, allowing any rank; b3 modifies rank.
|
# Explicit shape invariant, allowing any rank; b3 modifies rank.
|
||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
c, b3, [i, x],
|
c, b3, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape([None])])
|
[i.get_shape(), tensor_shape.TensorShape([None])])
|
||||||
self.assertEqual(r.indices.get_shape().as_list(), [None, None])
|
check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
|
||||||
self.assertEqual(r.values.get_shape().as_list(), [None])
|
|
||||||
self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
|
|
||||||
|
|
||||||
# Shape invariant with ndims=None. Technically, this isn't supported
|
# Shape invariant with ndims=None. Technically, this isn't supported
|
||||||
# according to the docs, but we support it for backwards compatibility.
|
# according to the docs, but we support it for backwards compatibility.
|
||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
c, b1, [i, x],
|
c, b1, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape(None)])
|
[i.get_shape(), tensor_shape.TensorShape(None)])
|
||||||
self.assertEqual(r.indices.get_shape().as_list(), [None, None])
|
check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
|
||||||
self.assertEqual(r.values.get_shape().as_list(), [None])
|
|
||||||
self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
|
|
||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
c, b3, [i, x],
|
c, b3, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape(None)])
|
[i.get_shape(), tensor_shape.TensorShape(None)])
|
||||||
self.assertEqual(r.indices.get_shape().as_list(), [None, None])
|
check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
|
||||||
self.assertEqual(r.values.get_shape().as_list(), [None])
|
|
||||||
self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
|
@test_util.disable_control_flow_v2("b/131265085")
|
||||||
|
@test_util.run_v1_only("b/131265085")
|
||||||
|
def testWhileBadShapeSparseTensor(self):
|
||||||
|
values = constant_op.constant([2.0, 4.0], name="values")
|
||||||
|
indices = constant_op.constant([[0], [3]],
|
||||||
|
dtype=dtypes.int64,
|
||||||
|
name="indices")
|
||||||
|
shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
|
||||||
|
i = constant_op.constant(0)
|
||||||
|
x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
|
||||||
|
c = lambda i, _: i < 10
|
||||||
|
b1 = lambda i, x: [i+1, x]
|
||||||
|
def b2(i, x): # modifies rank. (shape of all components is changed.)
|
||||||
|
return [
|
||||||
|
i + 1,
|
||||||
|
sparse_tensor.SparseTensor(
|
||||||
|
array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
|
||||||
|
array_ops.concat([x.dense_shape, [10]], axis=0))
|
||||||
|
]
|
||||||
|
|
||||||
# Explicit shape invariant, with a specific (incompatible) rank.
|
# Explicit shape invariant, with a specific (incompatible) rank.
|
||||||
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
|
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
|
||||||
_, r = control_flow_ops.while_loop(
|
control_flow_ops.while_loop(
|
||||||
c, b1, [i, x],
|
c, b1, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape([5])])
|
[i.get_shape(), tensor_shape.TensorShape([5])])
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
|
# Default shape invariant, but b2 modifies rank (which is not allowed).
|
||||||
@test_util.run_v1_only("b/120545219")
|
with self.assertRaises(ValueError):
|
||||||
|
control_flow_ops.while_loop(c, b2, [i, x])
|
||||||
|
|
||||||
def testWhileShapeInferenceIndexedSlices(self):
|
def testWhileShapeInferenceIndexedSlices(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
|
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
|
||||||
@ -1953,17 +1979,28 @@ class ControlFlowTest(test.TestCase):
|
|||||||
c, b, [i, x],
|
c, b, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
|
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
|
||||||
self.assertEqual(r.dense_shape.get_shape()[0], 2)
|
self.assertEqual(r.dense_shape.get_shape()[0], 2)
|
||||||
self.assertEqual(r.values.get_shape().as_list(), [None, 2])
|
self.assertTrue(r.values.get_shape().is_compatible_with([None, 2]))
|
||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
|
@test_util.disable_control_flow_v2("b/131265085")
|
||||||
_, r = control_flow_ops.while_loop(
|
@test_util.run_v1_only("b/131265085")
|
||||||
c, b, [i, x],
|
def testWhileBadShapeIndexedSlices(self):
|
||||||
[i.get_shape(), tensor_shape.TensorShape([None, 5])])
|
values = constant_op.constant([2.0, 4.0], name="values")
|
||||||
|
indices = constant_op.constant([[0], [3]],
|
||||||
|
dtype=dtypes.int64,
|
||||||
|
name="indices")
|
||||||
|
shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
|
||||||
|
i = constant_op.constant(0)
|
||||||
|
x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
|
||||||
|
c = lambda i, _: 10
|
||||||
|
b = lambda i, x: [i+1, x]
|
||||||
|
|
||||||
|
# Explicit shape invariant, with a specific (incompatible) rank.
|
||||||
|
with self.assertRaisesRegexp(ValueError, "is not compatible with"):
|
||||||
|
control_flow_ops.while_loop(
|
||||||
|
c, b, [i, x],
|
||||||
|
[i.get_shape(), tensor_shape.TensorShape([5])])
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
|
|
||||||
def testWhileShapeInferenceRaggedTensor(self):
|
def testWhileShapeInferenceRaggedTensor(self):
|
||||||
if context.executing_eagerly():
|
|
||||||
self.skipTest("b/116328420")
|
|
||||||
i = constant_op.constant(0)
|
i = constant_op.constant(0)
|
||||||
x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
|
x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
|
||||||
c = lambda i, _: i < 10
|
c = lambda i, _: i < 10
|
||||||
@ -1980,11 +2017,13 @@ class ControlFlowTest(test.TestCase):
|
|||||||
array_ops.concat([x, x], axis=0)
|
array_ops.concat([x, x], axis=0)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def check_shapes(r, values, splits):
|
||||||
|
self.assertTrue(r.values.shape.is_compatible_with(values))
|
||||||
|
self.assertTrue(r.row_splits.shape.is_compatible_with(splits))
|
||||||
|
|
||||||
# Default shape invariant; b1 adds new values to rows.
|
# Default shape invariant; b1 adds new values to rows.
|
||||||
_, r = control_flow_ops.while_loop(c, b1, [i, x])
|
_, r = control_flow_ops.while_loop(c, b1, [i, x])
|
||||||
self.assertEqual(r.row_splits.shape.as_list(), [4])
|
check_shapes(r, values=[None], splits=[4])
|
||||||
|
|
||||||
self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
|
|
||||||
|
|
||||||
# Default shape invariant; b2 adds new rows (not allowed).
|
# Default shape invariant; b2 adds new rows (not allowed).
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
@ -1995,20 +2034,15 @@ class ControlFlowTest(test.TestCase):
|
|||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
c, b1, [i, x],
|
c, b1, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape([None, None])])
|
[i.get_shape(), tensor_shape.TensorShape([None, None])])
|
||||||
self.assertTrue(r.row_splits.shape.as_list() in ([4], [None]))
|
check_shapes(r, values=[None], splits=[None])
|
||||||
self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
|
|
||||||
|
|
||||||
# Explicit shape invariant; b2 adds new rows.
|
# Explicit shape invariant; b2 adds new rows.
|
||||||
_, r = control_flow_ops.while_loop(
|
_, r = control_flow_ops.while_loop(
|
||||||
c, b2, [i, x],
|
c, b2, [i, x],
|
||||||
[i.get_shape(), tensor_shape.TensorShape([None, None])])
|
[i.get_shape(), tensor_shape.TensorShape([None, None])])
|
||||||
self.assertTrue(r.row_splits.shape.as_list() in ([3 * 2**10 + 1], [None]))
|
check_shapes(r, values=[None], splits=[None])
|
||||||
self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
|
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
|
|
||||||
def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
|
def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
|
||||||
if context.executing_eagerly():
|
|
||||||
self.skipTest("b/116328420")
|
|
||||||
i = constant_op.constant(0)
|
i = constant_op.constant(0)
|
||||||
x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
|
x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
|
||||||
[[], [8, 9, 10]]])
|
[[], [8, 9, 10]]])
|
||||||
@ -3473,8 +3507,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
self.assertEqual(0, value_x)
|
self.assertEqual(0, value_x)
|
||||||
self.assertEqual(73, value_x_grad)
|
self.assertEqual(73, value_x_grad)
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
|
@test_util.deprecated_graph_mode_only
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testWhileGrad_IndexedSlices(self):
|
def testWhileGrad_IndexedSlices(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
values = constant_op.constant([2.0, 4.0], name="values")
|
values = constant_op.constant([2.0, 4.0], name="values")
|
||||||
@ -3496,8 +3529,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
r = gradients_impl.gradients(r.values, values)[0]
|
r = gradients_impl.gradients(r.values, values)[0]
|
||||||
self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
|
self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
|
@test_util.deprecated_graph_mode_only
|
||||||
@test_util.run_v1_only("b/120545219")
|
|
||||||
def testWhileGrad_SparseTensor(self):
|
def testWhileGrad_SparseTensor(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
values = constant_op.constant([2.0, 4.0], name="values")
|
values = constant_op.constant([2.0, 4.0], name="values")
|
||||||
@ -3520,7 +3552,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
r = gradients_impl.gradients(r.values, values)[0]
|
r = gradients_impl.gradients(r.values, values)[0]
|
||||||
self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
|
self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.deprecated_graph_mode_only
|
||||||
def testCallGradInLoop(self):
|
def testCallGradInLoop(self):
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
i0 = constant_op.constant(0)
|
i0 = constant_op.constant(0)
|
||||||
|
@ -3466,7 +3466,7 @@ def while_loop(cond,
|
|||||||
return x
|
return x
|
||||||
return ops.convert_to_tensor(x)
|
return ops.convert_to_tensor(x)
|
||||||
|
|
||||||
loop_vars = nest.map_structure(convert, loop_vars)
|
loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
|
||||||
if maximum_iterations is not None:
|
if maximum_iterations is not None:
|
||||||
return loop_vars[1]
|
return loop_vars[1]
|
||||||
else:
|
else:
|
||||||
|
@ -72,12 +72,18 @@ def while_loop(cond,
|
|||||||
# `wrapped_body` below.
|
# `wrapped_body` below.
|
||||||
loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
|
loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
|
||||||
loop_vars = nest.map_structure(
|
loop_vars = nest.map_structure(
|
||||||
ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
|
ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
|
||||||
|
expand_composites=True)
|
||||||
if shape_invariants is not None:
|
if shape_invariants is not None:
|
||||||
nest.assert_same_structure(orig_loop_vars, shape_invariants)
|
nest.assert_same_structure(orig_loop_vars, shape_invariants,
|
||||||
|
expand_composites=False)
|
||||||
|
shape_invariants = nest.map_structure(
|
||||||
|
control_flow_ops._get_shape_invariant, loop_vars,
|
||||||
|
list(shape_invariants), expand_composites=False)
|
||||||
else:
|
else:
|
||||||
shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)
|
shape_invariants = nest.map_structure(
|
||||||
|
control_flow_ops._get_shape_invariant, loop_vars,
|
||||||
|
expand_composites=False)
|
||||||
if not name:
|
if not name:
|
||||||
name = "while"
|
name = "while"
|
||||||
|
|
||||||
@ -150,11 +156,12 @@ def while_loop(cond,
|
|||||||
# `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
|
# `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
|
||||||
# and packs it into the structure of `orig_loop_vars`.
|
# and packs it into the structure of `orig_loop_vars`.
|
||||||
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
|
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
|
||||||
if not nest.is_sequence(outputs):
|
if not nest.is_sequence_or_composite(outputs):
|
||||||
outputs = [outputs]
|
outputs = [outputs]
|
||||||
# Compare the structure of input and output of body converting the
|
# Compare the structure of input and output of body converting the
|
||||||
# top-level tuples to list to be compatible with legacy while_loop.
|
# top-level tuples to list to be compatible with legacy while_loop.
|
||||||
nest.assert_same_structure(list(outputs), list(orig_loop_vars))
|
nest.assert_same_structure(list(outputs), list(orig_loop_vars),
|
||||||
|
expand_composites=True)
|
||||||
|
|
||||||
outputs = _tensor_array_to_flow(outputs)
|
outputs = _tensor_array_to_flow(outputs)
|
||||||
|
|
||||||
@ -193,7 +200,8 @@ def while_loop(cond,
|
|||||||
# Make sure that the shapes of the loop outputs are compatible with the
|
# Make sure that the shapes of the loop outputs are compatible with the
|
||||||
# shape invariants, or the shapes of the loop vars if the invariants are not
|
# shape invariants, or the shapes of the loop vars if the invariants are not
|
||||||
# specified.
|
# specified.
|
||||||
num_flattened_outputs = len(nest.flatten(orig_loop_vars))
|
num_flattened_outputs = len(nest.flatten(orig_loop_vars,
|
||||||
|
expand_composites=True))
|
||||||
# First var is loop counter and second var is maximum_iterations.
|
# First var is loop counter and second var is maximum_iterations.
|
||||||
first_loop_var_index = 2
|
first_loop_var_index = 2
|
||||||
_check_shapes_compat(
|
_check_shapes_compat(
|
||||||
@ -201,10 +209,10 @@ def while_loop(cond,
|
|||||||
num_flattened_outputs],
|
num_flattened_outputs],
|
||||||
nest.flatten(
|
nest.flatten(
|
||||||
shape_invariants[first_loop_var_index:first_loop_var_index +
|
shape_invariants[first_loop_var_index:first_loop_var_index +
|
||||||
len_orig_loop_vars]),
|
len_orig_loop_vars], expand_composites=True),
|
||||||
nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
|
nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
|
||||||
len_orig_loop_vars]))
|
len_orig_loop_vars], expand_composites=True))
|
||||||
flattened_loop_vars = nest.flatten(loop_vars)
|
flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
|
||||||
_check_num_inputs_outputs(cond_graph, body_graph,
|
_check_num_inputs_outputs(cond_graph, body_graph,
|
||||||
len(flattened_loop_vars))
|
len(flattened_loop_vars))
|
||||||
|
|
||||||
@ -237,7 +245,7 @@ def while_loop(cond,
|
|||||||
if return_same_structure:
|
if return_same_structure:
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
flattened_outputs = nest.flatten(outputs)
|
flattened_outputs = nest.flatten(outputs, expand_composites=True)
|
||||||
if len(flattened_outputs) == 1:
|
if len(flattened_outputs) == 1:
|
||||||
return flattened_outputs[0]
|
return flattened_outputs[0]
|
||||||
else:
|
else:
|
||||||
@ -905,9 +913,11 @@ def _pack_sequence_as(structure_with_tas, loop_vars):
|
|||||||
|
|
||||||
flattened_loop_vars = [
|
flattened_loop_vars = [
|
||||||
flow_to_tensor_array(*z)
|
flow_to_tensor_array(*z)
|
||||||
for z in zip(nest.flatten(loop_vars), nest.flatten(structure_with_tas))
|
for z in zip(nest.flatten(loop_vars, expand_composites=True),
|
||||||
|
nest.flatten(structure_with_tas, expand_composites=True))
|
||||||
]
|
]
|
||||||
return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars)
|
return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars,
|
||||||
|
expand_composites=True)
|
||||||
|
|
||||||
|
|
||||||
def _tensor_array_to_flow(loop_vars):
|
def _tensor_array_to_flow(loop_vars):
|
||||||
@ -917,14 +927,15 @@ def _tensor_array_to_flow(loop_vars):
|
|||||||
return maybe_ta.flow
|
return maybe_ta.flow
|
||||||
return maybe_ta
|
return maybe_ta
|
||||||
|
|
||||||
return nest.map_structure(f, loop_vars)
|
return nest.map_structure(f, loop_vars, expand_composites=True)
|
||||||
|
|
||||||
|
|
||||||
def _build_signature(loop_vars, shape_invariants):
|
def _build_signature(loop_vars, shape_invariants):
|
||||||
return nest.pack_sequence_as(loop_vars, [
|
return nest.pack_sequence_as(loop_vars, [
|
||||||
tensor_spec.TensorSpec(s, t.dtype, name=t.op.name)
|
tensor_spec.TensorSpec(s, t.dtype, name=t.op.name)
|
||||||
for s, t in zip(nest.flatten(shape_invariants), nest.flatten(loop_vars))
|
for s, t in zip(nest.flatten(shape_invariants, expand_composites=True),
|
||||||
])
|
nest.flatten(loop_vars, expand_composites=True))
|
||||||
|
], expand_composites=True)
|
||||||
|
|
||||||
|
|
||||||
def _build_maximum_iterations_loop_var(maximum_iterations):
|
def _build_maximum_iterations_loop_var(maximum_iterations):
|
||||||
|
Loading…
Reference in New Issue
Block a user