[TF2XLA] Support asserts as no-ops for tf.function(jit_compile=True)

This is consistent with TPU and tfcompile behavior.

PiperOrigin-RevId: 348131078
Change-Id: I4f8807af60fdbe1d79f0e79db59df0feed79c94f
Author: George Karpenkov (2020-12-17 18:40:12 -08:00), committed by TensorFlower Gardener
parent edf4876054
commit 00475eda7c
6 changed files with 36 additions and 7 deletions
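In user terms, the change this commit makes: a `tf.Assert` inside a `tf.function(jit_compile=True)` now compiles to a no-op instead of failing compilation. A minimal sketch of that behavior, assuming a TF build that includes this change (only public APIs are used; `f` is illustrative, not taken from the commit):

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def f(x):
  # Compiled with XLA, the assert lowers to a no-op: the condition is never
  # checked at runtime and cannot raise.
  tf.Assert(x == 1, ['Wrong value'])
  return x * 2

print(f(tf.constant(1)).numpy())  # 2; before this change, compilation failed
print(f(tf.constant(5)).numpy())  # 10; no error raised, unlike eager/non-XLA TF
```

The flip side, which the known-issues doc updated below spells out, is that the failed condition in the second call is silently ignored.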


@@ -184,6 +184,7 @@ XLA_DEVICE_DEPS = [
     "//tensorflow/compiler/tf2xla:tf2xla_util",
     "//tensorflow/compiler/tf2xla:xla_compiler",
     "//tensorflow/compiler/tf2xla:xla_op_registry",
+    "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
     "//tensorflow/compiler/tf2xla/kernels:xla_ops",
     "//tensorflow/compiler/xla:util",
     "//tensorflow/compiler/xla/client:client_library",


@@ -196,12 +196,11 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(
         "SymbolicGradient should be handled by IsCompilableCall().";
     return false;
   }
   if (node.type_string() == "Const") {
     // Skip Const op with type DT_STRING, since XLA doesn't support it, but the
     // registered Const KernelDef says that it does, to support no-op Assert for
     // tfcompile.
     const AttrValue* attr = node.attrs().Find("dtype");
-    if (attr != nullptr && attr->type() == DT_STRING) {
+    if (!op_filter_.allow_string_consts && attr != nullptr &&
+        attr->type() == DT_STRING) {
       *uncompilable_reason =
           "Const op with type DT_STRING is not supported by XLA.";
       return false;

@@ -129,6 +129,9 @@ class RecursiveCompilabilityChecker {
     // Require the function to be always compilable, regardless whether some
     // control flow branches might be dead for a given input.
     bool require_always_compilable = false;
+
+    // Whether string constants are compilable.
+    bool allow_string_consts = true;
   };
 
   RecursiveCompilabilityChecker(OperationFilter op_filter,


@@ -1199,6 +1199,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
     RecursiveCompilabilityChecker::OperationFilter filter =
         CreateOperationFilter(*registration);
     filter.require_always_compilable = true;
+    filter.allow_string_consts = false;
 
     RecursiveCompilabilityChecker checker(
         filter, DeviceType{registration->compilation_device_name});
@@ -1207,6 +1208,15 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
       continue;
     }
 
+    if (node->type_string() == "Const") {
+      // Skip Const op with type DT_STRING, since XLA autoclustering doesn't
+      // support it.
+      const AttrValue* attr = node->attrs().Find("dtype");
+      if (attr != nullptr && attr->type() == DT_STRING) {
+        continue;
+      }
+    }
+
     if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
       VLOG(1) << "Rejecting TF operation " << node->def().op()
               << " as it is not listed in --tf_xla_ops_to_cluster.";


@@ -3,7 +3,7 @@
 Compilation with XLA can greatly improve the performance of your programs, but
 the TensorFlow interop has a number of known sharp corners.
 
-## TensorArray TF/XLA interconversion
+## TensorArray TF/XLA interconversion is not supported
 
 *Error message*:
 `Support for TensorList crossing the XLA/TF boundary is not implemented`.
@@ -31,7 +31,7 @@ intermediate results in a `TensorArray`, but XLA only supports bounded
 parameter set to a constant value known at compile time, or backpropagation
 disabled using `back_prop=False`.
 
-## Dynamic `tf.TensorArray`
+## Dynamic `tf.TensorArray` is not supported
 
 Writes into `tf.TensorArray(..., dynamic_size=True)` are not compilable with
 XLA, as such writes require an unknown number of reallocations when the array
@@ -39,9 +39,16 @@ exceeds the original bound.
 
 *Workaround*: provide a statically known bound to your arrays.
 
-## Random number generation
+## Random number generation ignores TF seed
 
 XLA currently ignores TF seeds to random operations. This affects stateful TF
 random operations, such as `tf.random.normal`, or `tf.nn.dropout`. XLA will
 behave as if the compilation was seeded with a new unique seed at each run. This
 limitation does not apply to stateless random ops.
+
+## TensorFlow Asserts are ignored
+
+Assertions created using `tf.Assert` and similar functions are noops when
+compiled to XLA. While proper assertion support is in principle possible, it
+might make certain optimizations impossible (mainly fusing the buffer on which
+the assertion is performed).
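For the seed limitation documented in the section above, stateless ops are the stated workaround. A minimal sketch, assuming `tf.random.stateless_normal` (public TF API) as the stateless counterpart of `tf.random.normal`:

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def add_noise(x, seed):
  # Stateless RNG: the draw is a pure function of `seed`, so XLA compilation
  # does not affect reproducibility, unlike tf.random.normal's TF-level seed.
  return x + tf.random.stateless_normal(tf.shape(x), seed=seed)

x = tf.zeros([4])
seed = tf.constant([1, 2], dtype=tf.int32)
print(add_noise(x, seed).numpy())
print(add_noise(x, seed).numpy())  # same seed -> identical values
```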


@@ -862,6 +862,15 @@ class DefFunctionTest(xla_test.XLATestCase):
       self.assertEqual(out.shape[0], 50)
       self.assertEqual(out.shape[1], 2)
 
+  def testTfAssert(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      @def_function.function(jit_compile=True)
+      def f(x):
+        control_flow_ops.Assert(x == 1, ['Wrong value'])
+
+      f(constant_op.constant(1))
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()