From 9392ffa09224f0a7735aa7076bee2024c39f1e69 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Wed, 26 Dec 2018 10:46:00 -0800 Subject: [PATCH] Improve compatibility of while_v2 with XLA tests Remove the assumption that resource variables cannot be included as outputs of the body. We instead iterate through the outputs to find the first resource variable index. Also loosen the requirements to specify maximum_iterations for XLA. PiperOrigin-RevId: 226932912 --- .../compiler/tf2xla/kernels/while_op.cc | 17 ++--- .../kernel_tests/control_flow_ops_py_test.py | 62 +++++++------------ tensorflow/python/ops/while_v2.py | 47 +++++++------- 3 files changed, 53 insertions(+), 73 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index b32683a682c..941b04363f8 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -291,20 +291,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); - auto while_shape_or = builder->GetShape(while_result); - OP_REQUIRES_OK(ctx, while_shape_or.status()); - auto count = xla::ShapeUtil::TupleElementCount(while_shape_or.ValueOrDie()); - int max_index = body.outputs.size() + body.resource_updates.size() - 1; - OP_REQUIRES( - ctx, max_index < count, - errors::Internal("Max tuple element requested (", max_index, - ") needs to be less than tuple size (", count, ")")); - - // Sets non-variable outputs. + // Sets non-variable outputs and determines where resource variables start. 
+ int resource_index = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { ctx->SetOutput(body.input_mapping[i], xla::GetTupleElement(while_result, i)); + ++resource_index; + } else { + break; } } if (has_token_input_output_) { @@ -326,7 +321,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); if (update.modified) { - int pos = body.outputs.size() + i; + int pos = resource_index + i; OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index fa62acbfebf..29490387286 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1183,6 +1183,8 @@ class ControlFlowTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self): + if control_flow_util.ENABLE_CONTROL_FLOW_V2: + self.skipTest("WhileV2 does lazy evaluation of maximum_iterations") v = constant_op.constant(1.0) def inner_body(i, x): @@ -1203,44 +1205,27 @@ class ControlFlowTest(test.TestCase): gs = gradients_impl.gradients(loop_no_xla, v) self.evaluate(gs) # This should execute without error. - if control_flow_util.ENABLE_CONTROL_FLOW_V2: - xla_context = control_flow_ops.XLAControlFlowContext() - xla_context.Enter() - with self.assertRaisesRegexp( - ValueError, - r"maximum_iterations is None. It is required and must be statically " - r"known \(e.g. a constant value or known shape dimension\) when " - r"building while_loop in XLA context."): - loop_no_maxiter = create_while_loop() - with self.assertRaisesRegexp( - ValueError, - r"maximum_iterations must be statically " - r"known \(e.g. 
a constant value or known shape dimension\) when " - r"building while_loop in XLA context."): - loop_with_maxiter = create_while_loop(maximum_iterations=2) - xla_context.Exit() - else: - xla_context = control_flow_ops.XLAControlFlowContext() - xla_context.Enter() - loop_no_maxiter = create_while_loop() - loop_with_maxiter = create_while_loop(maximum_iterations=2) - xla_context.Exit() + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + loop_no_maxiter = create_while_loop() + loop_with_maxiter = create_while_loop(maximum_iterations=2) + xla_context.Exit() - with self.assertRaisesRegexp( - ValueError, - r"Cannot create a gradient accumulator for tensor '.+' inside " - r"XLA while_loop because maximum_iterations was not passed to " - r"the tf.while_loop call \('.+'\)."): - _ = gradients_impl.gradients(loop_no_maxiter, v) + with self.assertRaisesRegexp( + ValueError, + r"Cannot create a gradient accumulator for tensor '.+' inside " + r"XLA while_loop because maximum_iterations was not passed to " + r"the tf.while_loop call \('.+'\)."): + _ = gradients_impl.gradients(loop_no_maxiter, v) - with self.assertRaisesRegexp( - ValueError, - r"Cannot create a gradient accumulator for tensor '.+' inside XLA " - r"while_loop. maximum_iterations tensor '.+' for while_loop context " - r"'.+' must be statically known \(e.g. a constant value or known " - r"shape dimension\), or be defined at or outside the while loop " - r"context '.*' \(currently defined in '.*'\)"): - _ = gradients_impl.gradients(loop_with_maxiter, v) + with self.assertRaisesRegexp( + ValueError, + r"Cannot create a gradient accumulator for tensor '.+' inside XLA " + r"while_loop. maximum_iterations tensor '.+' for while_loop context " + r"'.+' must be statically known \(e.g. 
a constant value or known " + r"shape dimension\), or be defined at or outside the while loop " + r"context '.*' \(currently defined in '.*'\)"): + _ = gradients_impl.gradients(loop_with_maxiter, v) @test_util.run_v1_only("b/120545219") def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self): @@ -1265,10 +1250,7 @@ class ControlFlowTest(test.TestCase): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() with self.assertRaisesRegexp( - ValueError, - r"maximum_iterations must be statically known \(e.g. a constant value" - r" or known shape dimension\) when building while_loop in XLA " - r"context."): + ValueError, r"Tensor.*Placeholder:0.* must be from the same graph.*"): loop = create_while_loop() xla_context.Exit() else: diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 295686f8143..8547cb94a03 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -254,6 +254,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name maximum_iterations = op.get_attr( "_maximum_iterations") if _is_in_xla_context() else None assert not _is_in_xla_context() or maximum_iterations is not None + maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations) # Set the incoming gradient of non-trainable inputs to None. It is possible # that we receive non-None gradients for non-trainable types in nested while @@ -376,28 +377,30 @@ def _validate_and_convert_to_tensor(maximum_iterations): Raises: ValueError: If `maximum_iterations` is invalid. """ - if _is_in_xla_context(): - if maximum_iterations is None: - raise ValueError("maximum_iterations is None. It is required and must " - "be statically known (e.g. 
a constant value or known " - "shape dimension) when building while_loop in XLA " - "context.") - if isinstance(maximum_iterations, ops.Tensor): - # Get the constant value from the `maximum_iterations` tensor to avoid - # capturing a Const tensor from outside this graph. - maximum_iterations = tensor_util.constant_value(maximum_iterations) - if maximum_iterations is None: - raise ValueError("maximum_iterations must be statically known (e.g. a " - "constant value or known shape dimension) when " - "building while_loop in XLA context.") + if maximum_iterations is None: + return None + + if _is_in_xla_context() and isinstance(maximum_iterations, ops.Tensor): + # Get the constant value from the `maximum_iterations` tensor to avoid + # capturing a Const tensor from outside this graph. + value = tensor_util.constant_value(maximum_iterations) + if value is None: + # XLA requires maximum_iterations to be statically known (e.g. a + # constant value or known shape dimension) when intermediate values + # from the forward pass are needed in the gradients pass. However, + # maximum_iterations may not be required if the gradient isn't built + # or no intermediates are required, thus we return the tensor as is. + return maximum_iterations + + maximum_iterations = value + + # EmptyTensorList expects `max_num_elements` to be of type int32. + maximum_iterations = ops.convert_to_tensor( + maximum_iterations, dtype=dtypes.int32, name="maximum_iterations") + if maximum_iterations.shape.ndims != 0: + raise ValueError("maximum_iterations must be a scalar, saw shape: %s" % + maximum_iterations.shape) - if maximum_iterations is not None: - # EmptyTensorList expects `max_num_elements` to be of type int32. 
- maximum_iterations = ops.convert_to_tensor( - maximum_iterations, dtype=dtypes.int32, name="maximum_iterations") - if maximum_iterations.shape.ndims != 0: - raise ValueError("maximum_iterations must be a scalar, saw shape: %s" % - maximum_iterations.shape) return maximum_iterations @@ -815,7 +818,7 @@ def _copy_handle_data(src_tensors, tgt_tensors): def _maybe_set_maximum_iterations_attr(op, maximum_iterations): - if control_flow_util.IsInXLAContext(op): + if maximum_iterations is not None and control_flow_util.IsInXLAContext(op): # Store the maximum_iterations to use in the gradient pass. op._set_attr( # pylint: disable=protected-access "_maximum_iterations",