Improve compatibility of while_v2 with XLA tests
Remove the assumption that resource variables cannot be included as outputs
of the body: instead, iterate through the outputs to find the index of the
first resource variable. Also loosen the requirement that maximum_iterations
be specified for XLA.

PiperOrigin-RevId: 226932912

Commit 9392ffa092 (parent 9585116b80)
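As a quick illustration of the loosened requirement, here is a minimal sketch
(graph-mode TF 1.x API; `create_while_loop` is a stand-in for the test helper
of the same name used below), an approximation rather than part of the commit:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

v = tf.constant(2.0)

def create_while_loop(maximum_iterations=None):
  # Stand-in for the test helper: a simple loop whose body captures `v`.
  return tf.while_loop(
      lambda i, x: i < 5,
      lambda i, x: (i + 1, x * v),
      [tf.constant(0), tf.constant(1.0)],
      maximum_iterations=maximum_iterations)[1]

xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
# Before this change, while_v2 raised ValueError here when maximum_iterations
# was omitted or not statically known. It now builds; the requirement is
# enforced lazily, only if the gradient pass needs forward-pass intermediates
# (see the updated tests below).
loop = create_while_loop()
xla_context.Exit()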
tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -291,20 +291,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
   xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init);
 
-  auto while_shape_or = builder->GetShape(while_result);
-  OP_REQUIRES_OK(ctx, while_shape_or.status());
-  auto count = xla::ShapeUtil::TupleElementCount(while_shape_or.ValueOrDie());
-  int max_index = body.outputs.size() + body.resource_updates.size() - 1;
-  OP_REQUIRES(
-      ctx, max_index < count,
-      errors::Internal("Max tuple element requested (", max_index,
-                       ") needs to be less than tuple size (", count, ")"));
-
-  // Sets non-variable outputs.
+  // Sets non-variable outputs and determine when resource variables start.
+  int resource_index = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
     if (ctx->input_type(i) != DT_RESOURCE) {
       ctx->SetOutput(body.input_mapping[i],
                      xla::GetTupleElement(while_result, i));
+      ++resource_index;
+    } else {
+      break;
     }
   }
 
   if (has_token_input_output_) {
@@ -326,7 +321,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
     XlaResource* resource;
     OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
     if (update.modified) {
-      int pos = body.outputs.size() + i;
+      int pos = resource_index + i;
       OP_REQUIRES_OK(ctx,
                      resource->SetFromPack(
                          arguments[update.input_index].tensor_array_gradients,
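The two C++ hunks above rely on the while result tuple being laid out as
non-resource outputs followed by resource updates. A hypothetical plain-Python
illustration (not TF API) of the scan that replaces the old
`body.outputs.size()` assumption:

def first_resource_index(input_types):
  # Count leading non-resource outputs; resource variables occupy the tail
  # of the tuple, so stop at the first DT_RESOURCE input.
  resource_index = 0
  for t in input_types:
    if t != "DT_RESOURCE":
      resource_index += 1
    else:
      break
  return resource_index

# Resource update i is then read from tuple position resource_index + i
# rather than body.outputs.size() + i.
assert first_resource_index(["DT_FLOAT", "DT_INT32", "DT_RESOURCE"]) == 2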
tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1183,6 +1183,8 @@ class ControlFlowTest(test.TestCase):
 
   @test_util.run_v1_only("b/120545219")
   def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
+    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
+      self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
     v = constant_op.constant(1.0)
 
     def inner_body(i, x):
@@ -1203,44 +1205,27 @@ class ControlFlowTest(test.TestCase):
     gs = gradients_impl.gradients(loop_no_xla, v)
     self.evaluate(gs)  # This should execute without error.
 
-    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
-      xla_context = control_flow_ops.XLAControlFlowContext()
-      xla_context.Enter()
-      with self.assertRaisesRegexp(
-          ValueError,
-          r"maximum_iterations is None. It is required and must be statically "
-          r"known \(e.g. a constant value or known shape dimension\) when "
-          r"building while_loop in XLA context."):
-        loop_no_maxiter = create_while_loop()
-      with self.assertRaisesRegexp(
-          ValueError,
-          r"maximum_iterations must be statically "
-          r"known \(e.g. a constant value or known shape dimension\) when "
-          r"building while_loop in XLA context."):
-        loop_with_maxiter = create_while_loop(maximum_iterations=2)
-      xla_context.Exit()
-    else:
-      xla_context = control_flow_ops.XLAControlFlowContext()
-      xla_context.Enter()
-      loop_no_maxiter = create_while_loop()
-      loop_with_maxiter = create_while_loop(maximum_iterations=2)
-      xla_context.Exit()
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    loop_no_maxiter = create_while_loop()
+    loop_with_maxiter = create_while_loop(maximum_iterations=2)
+    xla_context.Exit()
 
     with self.assertRaisesRegexp(
         ValueError,
         r"Cannot create a gradient accumulator for tensor '.+' inside "
         r"XLA while_loop because maximum_iterations was not passed to "
         r"the tf.while_loop call \('.+'\)."):
       _ = gradients_impl.gradients(loop_no_maxiter, v)
 
     with self.assertRaisesRegexp(
         ValueError,
         r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
         r"while_loop. maximum_iterations tensor '.+' for while_loop context "
         r"'.+' must be statically known \(e.g. a constant value or known "
         r"shape dimension\), or be defined at or outside the while loop "
         r"context '.*' \(currently defined in '.*'\)"):
       _ = gradients_impl.gradients(loop_with_maxiter, v)
 
   @test_util.run_v1_only("b/120545219")
   def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
@@ -1265,10 +1250,7 @@ class ControlFlowTest(test.TestCase):
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
       with self.assertRaisesRegexp(
-          ValueError,
-          r"maximum_iterations must be statically known \(e.g. a constant value"
-          r" or known shape dimension\) when building while_loop in XLA "
-          r"context."):
+          ValueError, r"Tensor.*Placeholder:0.* must be from the same graph.*"):
         loop = create_while_loop()
       xla_context.Exit()
     else:
tensorflow/python/ops/while_v2.py
@@ -254,6 +254,7 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
   maximum_iterations = op.get_attr(
       "_maximum_iterations") if _is_in_xla_context() else None
   assert not _is_in_xla_context() or maximum_iterations is not None
+  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
 
   # Set the incoming gradient of non-trainable inputs to None. It is possible
   # that we receive non-None gradients for non-trainable types in nested while
@@ -376,28 +377,30 @@ def _validate_and_convert_to_tensor(maximum_iterations):
   Raises:
     ValueError: If `maximum_iterations` is invalid.
   """
-  if _is_in_xla_context():
-    if maximum_iterations is None:
-      raise ValueError("maximum_iterations is None. It is required and must "
-                       "be statically known (e.g. a constant value or known "
-                       "shape dimension) when building while_loop in XLA "
-                       "context.")
-    if isinstance(maximum_iterations, ops.Tensor):
-      # Get the constant value from the `maximum_iterations` tensor to avoid
-      # capturing a Const tensor from outside this graph.
-      maximum_iterations = tensor_util.constant_value(maximum_iterations)
-      if maximum_iterations is None:
-        raise ValueError("maximum_iterations must be statically known (e.g. a "
-                         "constant value or known shape dimension) when "
-                         "building while_loop in XLA context.")
-
-  if maximum_iterations is not None:
-    # EmptyTensorList expects `max_num_elements` to be of type int32.
-    maximum_iterations = ops.convert_to_tensor(
-        maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
-    if maximum_iterations.shape.ndims != 0:
-      raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
-                       maximum_iterations.shape)
+  if maximum_iterations is None:
+    return None
+
+  if _is_in_xla_context() and isinstance(maximum_iterations, ops.Tensor):
+    # Get the constant value from the `maximum_iterations` tensor to avoid
+    # capturing a Const tensor from outside this graph.
+    value = tensor_util.constant_value(maximum_iterations)
+    if value is None:
+      # XLA requires maximum_iterations to be statically known (e.g. a
+      # constant value or known shape dimension) when intermediate values
+      # from the forward pass are needed in the gradients pass. However,
+      # maximum_iterations may not be required if the gradient isn't built
+      # or no intermediates are required, thus we return the tensor as is.
+      return maximum_iterations
+
+    maximum_iterations = value
+
+  # EmptyTensorList expects `max_num_elements` to be of type int32.
+  maximum_iterations = ops.convert_to_tensor(
+      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
+  if maximum_iterations.shape.ndims != 0:
+    raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
+                     maximum_iterations.shape)
   return maximum_iterations
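To summarize the rewritten helper, a standalone sketch of its contract (plain
Python mirroring the logic above, not calling the private TF function; `int()`
stands in for the int32 scalar tensor conversion):

def validate(maximum_iterations, in_xla_context, constant_value_fn):
  if maximum_iterations is None:
    return None                    # Deferred: no eager ValueError anymore.
  if in_xla_context:
    value = constant_value_fn(maximum_iterations)
    if value is None:
      return maximum_iterations    # Dynamic tensor is returned as-is.
    maximum_iterations = value
  return int(maximum_iterations)   # Otherwise convert to a scalar int32.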
@@ -815,7 +818,7 @@ def _copy_handle_data(src_tensors, tgt_tensors):
 
 
 def _maybe_set_maximum_iterations_attr(op, maximum_iterations):
-  if control_flow_util.IsInXLAContext(op):
+  if maximum_iterations is not None and control_flow_util.IsInXLAContext(op):
     # Store the maximum_iterations to use in the gradient pass.
     op._set_attr(  # pylint: disable=protected-access
         "_maximum_iterations",