Improve compatibility of while_v2 with XLA tests

Remove the assumption that resource variables cannot be included as
outputs of the body. We instead iterate through the outputs to find
the index of the first resource variable.
Also loosen the requirement to specify maximum_iterations for XLA.

PiperOrigin-RevId: 226932912
Authored by Gaurav Jain on 2018-12-26 10:46:00 -08:00; committed by TensorFlower Gardener
parent 9585116b80
commit 9392ffa092
3 changed files with 53 additions and 73 deletions

tensorflow/compiler/tf2xla/kernels/while_op.cc

@@ -291,20 +291,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
   xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init);
   auto while_shape_or = builder->GetShape(while_result);
   OP_REQUIRES_OK(ctx, while_shape_or.status());
-  auto count = xla::ShapeUtil::TupleElementCount(while_shape_or.ValueOrDie());
-  int max_index = body.outputs.size() + body.resource_updates.size() - 1;
-  OP_REQUIRES(
-      ctx, max_index < count,
-      errors::Internal("Max tuple element requested (", max_index,
-                       ") needs to be less than tuple size (", count, ")"));
-  // Sets non-variable outputs.
+  // Sets non-variable outputs and determine when resource variables start.
+  int resource_index = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
     if (ctx->input_type(i) != DT_RESOURCE) {
       ctx->SetOutput(body.input_mapping[i],
                      xla::GetTupleElement(while_result, i));
+      ++resource_index;
+    } else {
+      break;
     }
   }
   if (has_token_input_output_) {
@@ -326,7 +321,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
     XlaResource* resource;
     OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
     if (update.modified) {
-      int pos = body.outputs.size() + i;
+      int pos = resource_index + i;
       OP_REQUIRES_OK(ctx,
                      resource->SetFromPack(
                          arguments[update.input_index].tensor_array_gradients,

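As a rough sketch of the indexing scheme introduced in the while_op.cc diff above (hypothetical, standalone Python; the `output_types` list is made up for illustration), the new scan counts the non-resource outputs, which precede all resource variables in the while result tuple, so the i-th resource update can then be addressed at slot `resource_index + i`:

    # Hypothetical sketch of the scan in XlaWhileOp::Compile; "resource"
    # stands in for DT_RESOURCE. Non-resource outputs come first, so the
    # scan stops at the first resource variable.
    output_types = ["int32", "float32", "resource", "resource"]

    resource_index = 0
    for dtype in output_types:
      if dtype != "resource":
        resource_index += 1
      else:
        break

    # The i-th resource update lives at tuple slot resource_index + i,
    # replacing the old assumption that it lived at num_body_outputs + i.
    for i in range(len(output_types) - resource_index):
      print("resource update", i, "-> tuple element", resource_index + i)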
tensorflow/python/kernel_tests/control_flow_ops_py_test.py

@@ -1183,6 +1183,8 @@ class ControlFlowTest(test.TestCase):
   @test_util.run_v1_only("b/120545219")
   def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
+    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
+      self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
     v = constant_op.constant(1.0)

     def inner_body(i, x):
@@ -1203,44 +1205,27 @@ class ControlFlowTest(test.TestCase):
     gs = gradients_impl.gradients(loop_no_xla, v)
     self.evaluate(gs)  # This should execute without error.

-    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
-      xla_context = control_flow_ops.XLAControlFlowContext()
-      xla_context.Enter()
-      with self.assertRaisesRegexp(
-          ValueError,
-          r"maximum_iterations is None. It is required and must be statically "
-          r"known \(e.g. a constant value or known shape dimension\) when "
-          r"building while_loop in XLA context."):
-        loop_no_maxiter = create_while_loop()
-      with self.assertRaisesRegexp(
-          ValueError,
-          r"maximum_iterations must be statically "
-          r"known \(e.g. a constant value or known shape dimension\) when "
-          r"building while_loop in XLA context."):
-        loop_with_maxiter = create_while_loop(maximum_iterations=2)
-      xla_context.Exit()
-    else:
-      xla_context = control_flow_ops.XLAControlFlowContext()
-      xla_context.Enter()
-      loop_no_maxiter = create_while_loop()
-      loop_with_maxiter = create_while_loop(maximum_iterations=2)
-      xla_context.Exit()
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    loop_no_maxiter = create_while_loop()
+    loop_with_maxiter = create_while_loop(maximum_iterations=2)
+    xla_context.Exit()

-      with self.assertRaisesRegexp(
-          ValueError,
-          r"Cannot create a gradient accumulator for tensor '.+' inside "
-          r"XLA while_loop because maximum_iterations was not passed to "
-          r"the tf.while_loop call \('.+'\)."):
-        _ = gradients_impl.gradients(loop_no_maxiter, v)
-      with self.assertRaisesRegexp(
-          ValueError,
-          r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
-          r"while_loop. maximum_iterations tensor '.+' for while_loop context "
-          r"'.+' must be statically known \(e.g. a constant value or known "
-          r"shape dimension\), or be defined at or outside the while loop "
-          r"context '.*' \(currently defined in '.*'\)"):
-        _ = gradients_impl.gradients(loop_with_maxiter, v)
+    with self.assertRaisesRegexp(
+        ValueError,
+        r"Cannot create a gradient accumulator for tensor '.+' inside "
+        r"XLA while_loop because maximum_iterations was not passed to "
+        r"the tf.while_loop call \('.+'\)."):
+      _ = gradients_impl.gradients(loop_no_maxiter, v)
+    with self.assertRaisesRegexp(
+        ValueError,
+        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
+        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
+        r"'.+' must be statically known \(e.g. a constant value or known "
+        r"shape dimension\), or be defined at or outside the while loop "
+        r"context '.*' \(currently defined in '.*'\)"):
+      _ = gradients_impl.gradients(loop_with_maxiter, v)

   @test_util.run_v1_only("b/120545219")
   def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
@@ -1265,10 +1250,7 @@ class ControlFlowTest(test.TestCase):
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
       with self.assertRaisesRegexp(
-          ValueError,
-          r"maximum_iterations must be statically known \(e.g. a constant value"
-          r" or known shape dimension\) when building while_loop in XLA "
-          r"context."):
+          ValueError, r"Tensor.*Placeholder:0.* must be from the same graph.*"):
         loop = create_while_loop()
       xla_context.Exit()
     else:

tensorflow/python/ops/while_v2.py

@@ -254,6 +254,7 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
   maximum_iterations = op.get_attr(
       "_maximum_iterations") if _is_in_xla_context() else None
   assert not _is_in_xla_context() or maximum_iterations is not None
+  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)

   # Set the incoming gradient of non-trainable inputs to None. It is possible
   # that we receive non-None gradients for non-trainable types in nested while
@@ -376,28 +377,30 @@ def _validate_and_convert_to_tensor(maximum_iterations):
   Raises:
     ValueError: If `maximum_iterations` is invalid.
   """
-  if _is_in_xla_context():
-    if maximum_iterations is None:
-      raise ValueError("maximum_iterations is None. It is required and must "
-                       "be statically known (e.g. a constant value or known "
-                       "shape dimension) when building while_loop in XLA "
-                       "context.")
-    if isinstance(maximum_iterations, ops.Tensor):
-      # Get the constant value from the `maximum_iterations` tensor to avoid
-      # capturing a Const tensor from outside this graph.
-      maximum_iterations = tensor_util.constant_value(maximum_iterations)
-      if maximum_iterations is None:
-        raise ValueError("maximum_iterations must be statically known (e.g. a "
-                         "constant value or known shape dimension) when "
-                         "building while_loop in XLA context.")
+  if maximum_iterations is None:
+    return None
+  if _is_in_xla_context() and isinstance(maximum_iterations, ops.Tensor):
+    # Get the constant value from the `maximum_iterations` tensor to avoid
+    # capturing a Const tensor from outside this graph.
+    value = tensor_util.constant_value(maximum_iterations)
+    if value is None:
+      # XLA requires maximum_iterations to be statically known (e.g. a
+      # constant value or known shape dimension) when intermediate values
+      # from the forward pass are needed in the gradients pass. However,
+      # maximum_iterations may not be required if the gradient isn't built
+      # or no intermediates are required, thus we return the tensor as is.
+      return maximum_iterations
+    maximum_iterations = value

-  if maximum_iterations is not None:
-    # EmptyTensorList expects `max_num_elements` to be of type int32.
-    maximum_iterations = ops.convert_to_tensor(
-        maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
-    if maximum_iterations.shape.ndims != 0:
-      raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
-                       maximum_iterations.shape)
+  # EmptyTensorList expects `max_num_elements` to be of type int32.
+  maximum_iterations = ops.convert_to_tensor(
+      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
+  if maximum_iterations.shape.ndims != 0:
+    raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
+                     maximum_iterations.shape)
   return maximum_iterations
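The reworked helper hinges on `tensor_util.constant_value`, which returns a tensor's Python value when it is statically known and `None` otherwise; only in the latter case does the new code defer, returning the tensor as is rather than raising at loop-construction time. A minimal illustration (graph-mode TF 1.x, matching this era of the codebase):

    import tensorflow as tf
    from tensorflow.python.framework import tensor_util

    with tf.Graph().as_default():
      const = tf.constant(10)
      # Statically known: constant_value recovers the Python value.
      assert tensor_util.constant_value(const) == 10

      ph = tf.placeholder(tf.int32, shape=[])
      # Not statically known: constant_value returns None, and
      # _validate_and_convert_to_tensor now returns such a tensor
      # unchanged instead of raising.
      assert tensor_util.constant_value(ph) is None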
@@ -815,7 +818,7 @@ def _copy_handle_data(src_tensors, tgt_tensors):

 def _maybe_set_maximum_iterations_attr(op, maximum_iterations):
-  if control_flow_util.IsInXLAContext(op):
+  if maximum_iterations is not None and control_flow_util.IsInXLAContext(op):
     # Store the maximum_iterations to use in the gradient pass.
     op._set_attr(  # pylint: disable=protected-access
         "_maximum_iterations",