Move must-compilation from experimental_compile=True to a separate attribute

Previously, must-compilation semantics resulting from experimental_compile=True
used the attribute _XlaCompile, which is the same one as used by
auto-clustering to mark nodes to compile. This can result in a number of
unfortunate collisions.

PiperOrigin-RevId: 286678704
Change-Id: If3b58f18eb4116ae81ef266ebc6bdc8cab600e97
This commit is contained in:
George Karpenkov 2019-12-20 20:28:03 -08:00 committed by TensorFlower Gardener
parent 801b09624f
commit 375654d0ff
7 changed files with 50 additions and 14 deletions

View File

@ -17,6 +17,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
const char* const kXlaMustCompileAttr = "_XlaMustCompile";
const char* const kXlaCompileAttr = "_XlaCompile"; const char* const kXlaCompileAttr = "_XlaCompile";
// User-provided through jit_scope APIs. Effective only when auto_jit is OFF. // User-provided through jit_scope APIs. Effective only when auto_jit is OFF.

View File

@ -22,7 +22,16 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
// Name of attribute used to tag operators for compilation with XLA // Name of attribute used to tag operators for compilation with XLA
// Implies must-compile semantics: either it will be compiled
// with XLA, or an error will be thrown.
extern const char* const kXlaMustCompileAttr; // "_XlaMustCompile"
// Implies auto-clustering: tagged nodes will be clustered and compiled with XLA
// on a best-effort basis.
extern const char* const kXlaCompileAttr; // "_XlaCompile" extern const char* const kXlaCompileAttr; // "_XlaCompile"
// Implies auto-clustering within the given scope.
extern const char* const kXlaScopeAttr; // "_XlaScope" extern const char* const kXlaScopeAttr; // "_XlaScope"
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope" extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"

View File

@ -95,7 +95,7 @@ AttrValue BoolAttr(bool b) {
TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) { TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
FunctionDef fdef = XTimesY(); FunctionDef fdef = XTimesY();
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true); (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
Init({fdef}); Init({fdef});
XlaKernelCreator xla_kernel_creator; XlaKernelCreator xla_kernel_creator;
@ -137,7 +137,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) { TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
FunctionDef fdef = XTimesY(); FunctionDef fdef = XTimesY();
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false); (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(false);
Init({fdef}); Init({fdef});
XlaKernelCreator xla_kernel_creator; XlaKernelCreator xla_kernel_creator;

View File

@ -79,21 +79,21 @@ bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
return false; return false;
} }
// If kXlaCompileAttr is set on the node_def, use its value. // If kXlaMustCompileAttr is set on the node_def, use its value.
const auto& it = node_def.attr().find(kXlaCompileAttr); const auto& it = node_def.attr().find(kXlaMustCompileAttr);
if (it != node_def.attr().end()) { if (it != node_def.attr().end()) {
return it->second.b(); return it->second.b();
} }
// kXlaCompileAttr is not set on node_def, check if it is set on // kXlaMustCompileAttr is not set on node_def, check if it is set on
// FunctionDef. // FunctionDef.
bool xla_compile = false; bool xla_compile = false;
Status status = flr.GetFunctionLibraryDefinition()->GetAttr( Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
node_def, kXlaCompileAttr, &xla_compile); node_def, kXlaMustCompileAttr, &xla_compile);
if (!status.ok() || !xla_compile) { if (!status.ok() || !xla_compile) {
if (VLOG_IS_ON(3)) { if (VLOG_IS_ON(3)) {
if (!status.ok()) { if (!status.ok()) {
VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " VLOG(3) << "No " << kXlaMustCompileAttr << " attr defined for "
<< node_def.op() << ". status=" << status.ToString(); << node_def.op() << ". status=" << status.ToString();
} else { } else {
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";

View File

@ -352,8 +352,8 @@ void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
} }
} }
Status ShouldCompileWithXLA(const EagerOperation* op, const EagerContext* ctx, Status MustCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
bool* compile_with_xla) { bool* compile_with_xla) {
if (!op->is_function()) { if (!op->is_function()) {
*compile_with_xla = false; *compile_with_xla = false;
return Status::OK(); return Status::OK();
@ -368,7 +368,7 @@ Status ShouldCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
} }
// Does node have an explicit request to compile or not? // Does node have an explicit request to compile or not?
Status status = op->Attrs().Get(kXlaCompileAttr, compile_with_xla); Status status = op->Attrs().Get(kXlaMustCompileAttr, compile_with_xla);
if (status.ok()) { if (status.ok()) {
DVLOG(2) << "Caller explicitly requested " DVLOG(2) << "Caller explicitly requested "
<< (*compile_with_xla ? "" : "not ") << (*compile_with_xla ? "" : "not ")
@ -383,7 +383,7 @@ Status ShouldCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
return errors::NotFound("Failed to find function '", op->Name(), "'"); return errors::NotFound("Failed to find function '", op->Name(), "'");
} }
status = GetNodeAttr(AttrSlice(&function_def->attr()), kXlaCompileAttr, status = GetNodeAttr(AttrSlice(&function_def->attr()), kXlaMustCompileAttr,
compile_with_xla); compile_with_xla);
if (status.ok()) { if (status.ok()) {
DVLOG(2) << "Function definition explicitly specifies " DVLOG(2) << "Function definition explicitly specifies "
@ -511,12 +511,12 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
bool run_function_with_flr = false; bool run_function_with_flr = false;
if (op->is_function()) { if (op->is_function()) {
bool compile_with_xla; bool compile_with_xla;
TF_RETURN_IF_ERROR(ShouldCompileWithXLA(op, ctx, &compile_with_xla)); TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
if (compile_with_xla) { if (compile_with_xla) {
// Note that it is not ideal, but currently correct, to set this // Note that it is not ideal, but currently correct, to set this
// attribute after computing the kernel cache key above. // attribute after computing the kernel cache key above.
// Note: If the attribute is already set to true, this is a noop. // Note: If the attribute is already set to true, this is a noop.
op->MutableAttrs()->Set(kXlaCompileAttr, true); op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
} else { } else {
run_function_with_flr = true; run_function_with_flr = true;
} }

View File

@ -449,7 +449,7 @@ class Function(object):
if self._implements is not None: if self._implements is not None:
attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements
if self._experimental_compile is not None: if self._experimental_compile is not None:
attributes.update(_XlaCompile=bool(self._experimental_compile)) attributes.update(_XlaMustCompile=bool(self._experimental_compile))
if not attributes: if not attributes:
attributes = None attributes = None
return function_lib.defun_with_attributes( return function_lib.defun_with_attributes(

View File

@ -45,6 +45,31 @@ class DefFunctionTest(test.TestCase):
# XLA support is not yet enabled for TF ROCm # XLA support is not yet enabled for TF ROCm
self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1)) self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
def testDerivative(self):
if test.is_built_with_rocm():
return
def fn(x, a):
return 2 * x + a
xla_func = def_function.function(fn, experimental_compile=True)
with backprop.GradientTape() as tape:
inputs = constant_op.constant([1., 2., 2., 3., 3.])
tape.watch(inputs)
outputs = xla_func(inputs, 1)
self.assertAllClose([2, 2, 2, 2, 2], tape.gradient(outputs, inputs))
# pylint: disable=protected-access
(forward, backward) = xla_func.get_concrete_function(
inputs, 1)._delayed_rewrite_functions.forward_backward()
# Check that the must-compile attribute gets correctly propagated to the
# created derivatives.
self.assertTrue(backward.function_def.attr['_XlaMustCompile'])
self.assertTrue(forward.definition.attr['_XlaMustCompile'])
def testUnsupportedOps(self): def testUnsupportedOps(self):
def fn(x): def fn(x):