[TF2XLA] Support asserts as no-ops for tf.function(jit_compile=True)
This is consistent with TPU and tfcompile behavior. PiperOrigin-RevId: 348131078 Change-Id: I4f8807af60fdbe1d79f0e79db59df0feed79c94f
This commit is contained in:
parent
edf4876054
commit
00475eda7c
@ -184,6 +184,7 @@ XLA_DEVICE_DEPS = [
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
|
@ -196,12 +196,11 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(
|
||||
"SymbolicGradient should be handled by IsCompilableCall().";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node.type_string() == "Const") {
|
||||
// Skip Const op with type DT_STRING, since XLA doesn't support it, but the
|
||||
// registered Const KernelDef says that it does, to support no-op Assert for
|
||||
// tfcompile.
|
||||
const AttrValue* attr = node.attrs().Find("dtype");
|
||||
if (attr != nullptr && attr->type() == DT_STRING) {
|
||||
if (!op_filter_.allow_string_consts && attr != nullptr &&
|
||||
attr->type() == DT_STRING) {
|
||||
*uncompilable_reason =
|
||||
"Const op with type DT_STRING is not supported by XLA.";
|
||||
return false;
|
||||
|
@ -129,6 +129,9 @@ class RecursiveCompilabilityChecker {
|
||||
// Require the function to be always compilable, regardless whether some
|
||||
// control flow branches might be dead for a given input.
|
||||
bool require_always_compilable = false;
|
||||
|
||||
// Whether string constants are compilable.
|
||||
bool allow_string_consts = true;
|
||||
};
|
||||
|
||||
RecursiveCompilabilityChecker(OperationFilter op_filter,
|
||||
|
@ -1199,6 +1199,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
RecursiveCompilabilityChecker::OperationFilter filter =
|
||||
CreateOperationFilter(*registration);
|
||||
filter.require_always_compilable = true;
|
||||
filter.allow_string_consts = false;
|
||||
|
||||
RecursiveCompilabilityChecker checker(
|
||||
filter, DeviceType{registration->compilation_device_name});
|
||||
@ -1207,6 +1208,15 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->type_string() == "Const") {
|
||||
// Skip Const op with type DT_STRING, since XLA autoclustering doesn't
|
||||
// support it.
|
||||
const AttrValue* attr = node->attrs().Find("dtype");
|
||||
if (attr != nullptr && attr->type() == DT_STRING) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
|
||||
VLOG(1) << "Rejecting TF operation " << node->def().op()
|
||||
<< " as it is not listed in --tf_xla_ops_to_cluster.";
|
||||
|
@ -3,7 +3,7 @@
|
||||
Compilation with XLA can greatly improve the performance of your programs, but
|
||||
the TensorFlow interop has a number of known sharp corners.
|
||||
|
||||
## TensorArray TF/XLA interconversion
|
||||
## TensorArray TF/XLA interconversion is not supported
|
||||
|
||||
*Error message*:
|
||||
`Support for TensorList crossing the XLA/TF boundary is not implemented`.
|
||||
@ -31,7 +31,7 @@ intermediate results in a `TensorArray`, but XLA only supports bounded
|
||||
parameter set to a constant value known at compile time, or backpropagation
|
||||
disabled using `back_prop=False`.
|
||||
|
||||
## Dynamic `tf.TensorArray`
|
||||
## Dynamic `tf.TensorArray` is not supported
|
||||
|
||||
Writes into `tf.TensorArray(..., dynamic_size=True)` are not compilable with
|
||||
XLA, as such writes require an unknown number of reallocations when the array
|
||||
@ -39,9 +39,16 @@ exceeds the original bound.
|
||||
|
||||
*Workaround*: provide a statically known bound to your arrays.
|
||||
|
||||
## Random number generation
|
||||
## Random number generation ignores TF seed
|
||||
|
||||
XLA currently ignores TF seeds to random operations. This affects stateful TF
|
||||
random operations, such as `tf.random.normal`, or `tf.nn.dropout`. XLA will
|
||||
behave as if the compilation was seeded with a new unique seed at each run. This
|
||||
limitation does not apply to stateless random ops.
|
||||
|
||||
## TensorFlow Asserts are ignored
|
||||
|
||||
Assertions created using `tf.Assert` and similar functions are noops when
|
||||
compiled to XLA. While proper assertion support is in principle possible, it
|
||||
might make certain optimizations impossible (mainly fusing the buffer on which
|
||||
the assertion is performed).
|
||||
|
@ -862,6 +862,15 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
self.assertEqual(out.shape[0], 50)
|
||||
self.assertEqual(out.shape[1], 2)
|
||||
|
||||
def testTfAssert(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@def_function.function(jit_compile=True)
|
||||
def f(x):
|
||||
control_flow_ops.Assert(x == 1, ['Wrong value'])
|
||||
|
||||
f(constant_op.constant(1))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user