Improve compatibility of while_v2 with XLA tests
Remove assumption where resource variables could not be included as outputs of the body. We instead iterate through the outputs to find the first resource variable index. Also loosen the requirements to specify maximum_iterations for XLA. PiperOrigin-RevId: 226932912
This commit is contained in:
parent
9585116b80
commit
9392ffa092
@ -291,20 +291,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
|
||||
|
||||
xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init);
|
||||
|
||||
auto while_shape_or = builder->GetShape(while_result);
|
||||
OP_REQUIRES_OK(ctx, while_shape_or.status());
|
||||
auto count = xla::ShapeUtil::TupleElementCount(while_shape_or.ValueOrDie());
|
||||
int max_index = body.outputs.size() + body.resource_updates.size() - 1;
|
||||
OP_REQUIRES(
|
||||
ctx, max_index < count,
|
||||
errors::Internal("Max tuple element requested (", max_index,
|
||||
") needs to be less than tuple size (", count, ")"));
|
||||
|
||||
// Sets non-variable outputs.
|
||||
// Sets non-variable outputs and determine when resource variables start.
|
||||
int resource_index = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
if (ctx->input_type(i) != DT_RESOURCE) {
|
||||
ctx->SetOutput(body.input_mapping[i],
|
||||
xla::GetTupleElement(while_result, i));
|
||||
++resource_index;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (has_token_input_output_) {
|
||||
@ -326,7 +321,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
|
||||
XlaResource* resource;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
|
||||
if (update.modified) {
|
||||
int pos = body.outputs.size() + i;
|
||||
int pos = resource_index + i;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
resource->SetFromPack(
|
||||
arguments[update.input_index].tensor_array_gradients,
|
||||
|
@ -1183,6 +1183,8 @@ class ControlFlowTest(test.TestCase):
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
|
||||
v = constant_op.constant(1.0)
|
||||
|
||||
def inner_body(i, x):
|
||||
@ -1203,23 +1205,6 @@ class ControlFlowTest(test.TestCase):
|
||||
gs = gradients_impl.gradients(loop_no_xla, v)
|
||||
self.evaluate(gs) # This should execute without error.
|
||||
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"maximum_iterations is None. It is required and must be statically "
|
||||
r"known \(e.g. a constant value or known shape dimension\) when "
|
||||
r"building while_loop in XLA context."):
|
||||
loop_no_maxiter = create_while_loop()
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"maximum_iterations must be statically "
|
||||
r"known \(e.g. a constant value or known shape dimension\) when "
|
||||
r"building while_loop in XLA context."):
|
||||
loop_with_maxiter = create_while_loop(maximum_iterations=2)
|
||||
xla_context.Exit()
|
||||
else:
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
loop_no_maxiter = create_while_loop()
|
||||
@ -1265,10 +1250,7 @@ class ControlFlowTest(test.TestCase):
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"maximum_iterations must be statically known \(e.g. a constant value"
|
||||
r" or known shape dimension\) when building while_loop in XLA "
|
||||
r"context."):
|
||||
ValueError, r"Tensor.*Placeholder:0.* must be from the same graph.*"):
|
||||
loop = create_while_loop()
|
||||
xla_context.Exit()
|
||||
else:
|
||||
|
@ -254,6 +254,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
|
||||
maximum_iterations = op.get_attr(
|
||||
"_maximum_iterations") if _is_in_xla_context() else None
|
||||
assert not _is_in_xla_context() or maximum_iterations is not None
|
||||
maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
|
||||
|
||||
# Set the incoming gradient of non-trainable inputs to None. It is possible
|
||||
# that we receive non-None gradients for non-trainable types in nested while
|
||||
@ -376,28 +377,30 @@ def _validate_and_convert_to_tensor(maximum_iterations):
|
||||
Raises:
|
||||
ValueError: If `maximum_iterations` is invalid.
|
||||
"""
|
||||
if _is_in_xla_context():
|
||||
if maximum_iterations is None:
|
||||
raise ValueError("maximum_iterations is None. It is required and must "
|
||||
"be statically known (e.g. a constant value or known "
|
||||
"shape dimension) when building while_loop in XLA "
|
||||
"context.")
|
||||
if isinstance(maximum_iterations, ops.Tensor):
|
||||
return None
|
||||
|
||||
if _is_in_xla_context() and isinstance(maximum_iterations, ops.Tensor):
|
||||
# Get the constant value from the `maximum_iterations` tensor to avoid
|
||||
# capturing a Const tensor from outside this graph.
|
||||
maximum_iterations = tensor_util.constant_value(maximum_iterations)
|
||||
if maximum_iterations is None:
|
||||
raise ValueError("maximum_iterations must be statically known (e.g. a "
|
||||
"constant value or known shape dimension) when "
|
||||
"building while_loop in XLA context.")
|
||||
value = tensor_util.constant_value(maximum_iterations)
|
||||
if value is None:
|
||||
# XLA requires maximum_iterations to be statically known (e.g. a
|
||||
# constant value or known shape dimension) when intermediate values
|
||||
# from the forward pass are needed in the gradients pass. However,
|
||||
# maximum_iterations may not be required if the gradient isn't built
|
||||
# or no intermediates are required, thus we return the tensor as is.
|
||||
return maximum_iterations
|
||||
|
||||
maximum_iterations = value
|
||||
|
||||
if maximum_iterations is not None:
|
||||
# EmptyTensorList expects `max_num_elements` to be of type int32.
|
||||
maximum_iterations = ops.convert_to_tensor(
|
||||
maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
|
||||
if maximum_iterations.shape.ndims != 0:
|
||||
raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
|
||||
maximum_iterations.shape)
|
||||
|
||||
return maximum_iterations
|
||||
|
||||
|
||||
@ -815,7 +818,7 @@ def _copy_handle_data(src_tensors, tgt_tensors):
|
||||
|
||||
|
||||
def _maybe_set_maximum_iterations_attr(op, maximum_iterations):
|
||||
if control_flow_util.IsInXLAContext(op):
|
||||
if maximum_iterations is not None and control_flow_util.IsInXLAContext(op):
|
||||
# Store the maximum_iterations to use in the gradient pass.
|
||||
op._set_attr( # pylint: disable=protected-access
|
||||
"_maximum_iterations",
|
||||
|
Loading…
Reference in New Issue
Block a user