diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index b32683a682c..941b04363f8 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -291,20 +291,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
 
   xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init);
 
-  auto while_shape_or = builder->GetShape(while_result);
-  OP_REQUIRES_OK(ctx, while_shape_or.status());
-  auto count = xla::ShapeUtil::TupleElementCount(while_shape_or.ValueOrDie());
-  int max_index = body.outputs.size() + body.resource_updates.size() - 1;
-  OP_REQUIRES(
-      ctx, max_index < count,
-      errors::Internal("Max tuple element requested (", max_index,
-                       ") needs to be less than tuple size (", count, ")"));
-
-  // Sets non-variable outputs.
+  // Sets non-variable outputs and determines where resource variables start.
+  int resource_index = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
     if (ctx->input_type(i) != DT_RESOURCE) {
       ctx->SetOutput(body.input_mapping[i],
                      xla::GetTupleElement(while_result, i));
+      ++resource_index;
+    } else {
+      break;
     }
   }
   if (has_token_input_output_) {
@@ -326,7 +321,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
     XlaResource* resource;
     OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
     if (update.modified) {
-      int pos = body.outputs.size() + i;
+      int pos = resource_index + i;
       OP_REQUIRES_OK(ctx,
                      resource->SetFromPack(
                          arguments[update.input_index].tensor_array_gradients,
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index fa62acbfebf..29490387286 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1183,6 +1183,8 @@ class ControlFlowTest(test.TestCase):
 
   @test_util.run_v1_only("b/120545219")
   def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
+    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
+      self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
     v = constant_op.constant(1.0)
 
     def inner_body(i, x):
@@ -1203,44 +1205,27 @@ class ControlFlowTest(test.TestCase):
     gs = gradients_impl.gradients(loop_no_xla, v)
     self.evaluate(gs)  # This should execute without error.
 
-    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
-      xla_context = control_flow_ops.XLAControlFlowContext()
-      xla_context.Enter()
-      with self.assertRaisesRegexp(
-          ValueError,
-          r"maximum_iterations is None. It is required and must be statically "
-          r"known \(e.g. a constant value or known shape dimension\) when "
-          r"building while_loop in XLA context."):
-        loop_no_maxiter = create_while_loop()
-      with self.assertRaisesRegexp(
-          ValueError,
-          r"maximum_iterations must be statically "
-          r"known \(e.g. a constant value or known shape dimension\) when "
-          r"building while_loop in XLA context."):
-        loop_with_maxiter = create_while_loop(maximum_iterations=2)
-      xla_context.Exit()
-    else:
-      xla_context = control_flow_ops.XLAControlFlowContext()
-      xla_context.Enter()
-      loop_no_maxiter = create_while_loop()
-      loop_with_maxiter = create_while_loop(maximum_iterations=2)
-      xla_context.Exit()
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    loop_no_maxiter = create_while_loop()
+    loop_with_maxiter = create_while_loop(maximum_iterations=2)
+    xla_context.Exit()
 
-      with self.assertRaisesRegexp(
-          ValueError,
-          r"Cannot create a gradient accumulator for tensor '.+' inside "
-          r"XLA while_loop because maximum_iterations was not passed to "
-          r"the tf.while_loop call \('.+'\)."):
-        _ = gradients_impl.gradients(loop_no_maxiter, v)
+    with self.assertRaisesRegexp(
+        ValueError,
+        r"Cannot create a gradient accumulator for tensor '.+' inside "
+        r"XLA while_loop because maximum_iterations was not passed to "
+        r"the tf.while_loop call \('.+'\)."):
+      _ = gradients_impl.gradients(loop_no_maxiter, v)
 
-      with self.assertRaisesRegexp(
-          ValueError,
-          r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
-          r"while_loop. maximum_iterations tensor '.+' for while_loop context "
-          r"'.+' must be statically known \(e.g. a constant value or known "
-          r"shape dimension\), or be defined at or outside the while loop "
-          r"context '.*' \(currently defined in '.*'\)"):
-        _ = gradients_impl.gradients(loop_with_maxiter, v)
+    with self.assertRaisesRegexp(
+        ValueError,
+        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
+        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
+        r"'.+' must be statically known \(e.g. a constant value or known "
+        r"shape dimension\), or be defined at or outside the while loop "
+        r"context '.*' \(currently defined in '.*'\)"):
+      _ = gradients_impl.gradients(loop_with_maxiter, v)
 
   @test_util.run_v1_only("b/120545219")
   def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
@@ -1265,10 +1250,7 @@ class ControlFlowTest(test.TestCase):
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
       with self.assertRaisesRegexp(
-          ValueError,
-          r"maximum_iterations must be statically known \(e.g. a constant value"
-          r" or known shape dimension\) when building while_loop in XLA "
-          r"context."):
+          ValueError, r"Tensor.*Placeholder:0.* must be from the same graph.*"):
         loop = create_while_loop()
       xla_context.Exit()
     else:
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 295686f8143..8547cb94a03 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -254,6 +254,7 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
   maximum_iterations = op.get_attr(
       "_maximum_iterations") if _is_in_xla_context() else None
   assert not _is_in_xla_context() or maximum_iterations is not None
+  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
 
   # Set the incoming gradient of non-trainable inputs to None. It is possible
   # that we receive non-None gradients for non-trainable types in nested while
@@ -376,28 +377,30 @@ def _validate_and_convert_to_tensor(maximum_iterations):
   Raises:
     ValueError: If `maximum_iterations` is invalid.
   """
-  if _is_in_xla_context():
-    if maximum_iterations is None:
-      raise ValueError("maximum_iterations is None. It is required and must "
-                       "be statically known (e.g. a constant value or known "
-                       "shape dimension) when building while_loop in XLA "
-                       "context.")
-    if isinstance(maximum_iterations, ops.Tensor):
-      # Get the constant value from the `maximum_iterations` tensor to avoid
-      # capturing a Const tensor from outside this graph.
-      maximum_iterations = tensor_util.constant_value(maximum_iterations)
-      if maximum_iterations is None:
-        raise ValueError("maximum_iterations must be statically known (e.g. a "
-                         "constant value or known shape dimension) when "
-                         "building while_loop in XLA context.")
+  if maximum_iterations is None:
+    return None
+
+  if _is_in_xla_context() and isinstance(maximum_iterations, ops.Tensor):
+    # Get the constant value from the `maximum_iterations` tensor to avoid
+    # capturing a Const tensor from outside this graph.
+    value = tensor_util.constant_value(maximum_iterations)
+    if value is None:
+      # XLA requires maximum_iterations to be statically known (e.g. a
+      # constant value or known shape dimension) when intermediate values
+      # from the forward pass are needed in the gradients pass. However,
+      # maximum_iterations may not be required if the gradient isn't built
+      # or no intermediates are required, thus we return the tensor as is.
+      return maximum_iterations
+
+    maximum_iterations = value
+
+  # EmptyTensorList expects `max_num_elements` to be of type int32.
+  maximum_iterations = ops.convert_to_tensor(
+      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
+  if maximum_iterations.shape.ndims != 0:
+    raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
+                     maximum_iterations.shape)
 
-  if maximum_iterations is not None:
-    # EmptyTensorList expects `max_num_elements` to be of type int32.
-    maximum_iterations = ops.convert_to_tensor(
-        maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
-    if maximum_iterations.shape.ndims != 0:
-      raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
-                       maximum_iterations.shape)
   return maximum_iterations
 
 
@@ -815,7 +818,7 @@ def _copy_handle_data(src_tensors, tgt_tensors):
 
 
 def _maybe_set_maximum_iterations_attr(op, maximum_iterations):
-  if control_flow_util.IsInXLAContext(op):
+  if maximum_iterations is not None and control_flow_util.IsInXLAContext(op):
     # Store the maximum_iterations to use in the gradient pass.
     op._set_attr(  # pylint: disable=protected-access
         "_maximum_iterations",
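
Note on the net effect of the `_validate_and_convert_to_tensor` hunk above, with a standalone sketch. This is a minimal illustration rather than the patched function itself: `validate_maximum_iterations` is a hypothetical name, and the `in_xla_context` parameter stands in for the module-internal `_is_in_xla_context()` check.

    # Minimal sketch of the reworked validation logic, for illustration only.
    import tensorflow as tf
    from tensorflow.python.framework import tensor_util


    def validate_maximum_iterations(maximum_iterations, in_xla_context):
      """Mirrors the patched validation: defer XLA errors instead of raising."""
      if maximum_iterations is None:
        # No bound supplied; any XLA complaint is deferred to the gradient pass.
        return None
      if in_xla_context and isinstance(maximum_iterations, tf.Tensor):
        value = tensor_util.constant_value(maximum_iterations)
        if value is None:
          # Not statically known: return the tensor unchanged instead of
          # raising, since the gradient pass may never need a static bound.
          return maximum_iterations
        maximum_iterations = value
      # EmptyTensorList expects `max_num_elements` to be int32, hence the cast.
      maximum_iterations = tf.convert_to_tensor(
          maximum_iterations, dtype=tf.int32, name="maximum_iterations")
      if maximum_iterations.shape.ndims != 0:
        raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
                         maximum_iterations.shape)
      return maximum_iterations

As the updated tests reflect, building a while loop in an XLA context with a non-constant `maximum_iterations` no longer raises at construction time; the "Cannot create a gradient accumulator" errors surface only when gradients are requested and forward-pass intermediates actually need to be accumulated.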