Move must-compilation from experimental_compile=True to a separate attribute
Previously, the must-compile semantics requested by experimental_compile=True were expressed with the _XlaCompile attribute, which is the same attribute that auto-clustering uses to mark nodes for compilation. Sharing one attribute for both purposes can result in a number of unfortunate collisions.
PiperOrigin-RevId: 286678704
Change-Id: If3b58f18eb4116ae81ef266ebc6bdc8cab600e97
This commit is contained in:
parent
801b09624f
commit
375654d0ff
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
const char* const kXlaMustCompileAttr = "_XlaMustCompile";
|
||||||
|
|
||||||
const char* const kXlaCompileAttr = "_XlaCompile";
|
const char* const kXlaCompileAttr = "_XlaCompile";
|
||||||
|
|
||||||
// User-provided through jit_scope APIs. Effective only when auto_jit is OFF.
|
// User-provided through jit_scope APIs. Effective only when auto_jit is OFF.
|
||||||
|
|
|
@ -22,7 +22,16 @@ limitations under the License.
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// Name of attribute used to tag operators for compilation with XLA
|
// Name of attribute used to tag operators for compilation with XLA
|
||||||
|
|
||||||
|
// Implies must-compile semantics: either it will be compiled
|
||||||
|
// with XLA, or an error will be thrown.
|
||||||
|
extern const char* const kXlaMustCompileAttr; // "_XlaMustCompile"
|
||||||
|
|
||||||
|
// Implies auto-clustering: tagged nodes will be clustered and compiled with XLA
|
||||||
|
// on a best-effort basis.
|
||||||
extern const char* const kXlaCompileAttr; // "_XlaCompile"
|
extern const char* const kXlaCompileAttr; // "_XlaCompile"
|
||||||
|
|
||||||
|
// Implies auto-clustering within the given scope.
|
||||||
extern const char* const kXlaScopeAttr; // "_XlaScope"
|
extern const char* const kXlaScopeAttr; // "_XlaScope"
|
||||||
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
|
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
|
||||||
|
|
||||||
|
|
|
@ -95,7 +95,7 @@ AttrValue BoolAttr(bool b) {
|
||||||
|
|
||||||
TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
|
TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
|
||||||
FunctionDef fdef = XTimesY();
|
FunctionDef fdef = XTimesY();
|
||||||
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
|
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
|
||||||
Init({fdef});
|
Init({fdef});
|
||||||
XlaKernelCreator xla_kernel_creator;
|
XlaKernelCreator xla_kernel_creator;
|
||||||
|
|
||||||
|
@ -137,7 +137,7 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
|
||||||
|
|
||||||
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
|
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
|
||||||
FunctionDef fdef = XTimesY();
|
FunctionDef fdef = XTimesY();
|
||||||
(*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
|
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(false);
|
||||||
Init({fdef});
|
Init({fdef});
|
||||||
XlaKernelCreator xla_kernel_creator;
|
XlaKernelCreator xla_kernel_creator;
|
||||||
|
|
||||||
|
|
|
@ -79,21 +79,21 @@ bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If kXlaCompileAttr is set on the node_def, use its value.
|
// If kXlaMustCompileAttr is set on the node_def, use its value.
|
||||||
const auto& it = node_def.attr().find(kXlaCompileAttr);
|
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
|
||||||
if (it != node_def.attr().end()) {
|
if (it != node_def.attr().end()) {
|
||||||
return it->second.b();
|
return it->second.b();
|
||||||
}
|
}
|
||||||
|
|
||||||
// kXlaCompileAttr is not set on node_def, check if it is set on
|
// kXlaMustCompileAttr is not set on node_def, check if it is set on
|
||||||
// FunctionDef.
|
// FunctionDef.
|
||||||
bool xla_compile = false;
|
bool xla_compile = false;
|
||||||
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
|
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
|
||||||
node_def, kXlaCompileAttr, &xla_compile);
|
node_def, kXlaMustCompileAttr, &xla_compile);
|
||||||
if (!status.ok() || !xla_compile) {
|
if (!status.ok() || !xla_compile) {
|
||||||
if (VLOG_IS_ON(3)) {
|
if (VLOG_IS_ON(3)) {
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
|
VLOG(3) << "No " << kXlaMustCompileAttr << " attr defined for "
|
||||||
<< node_def.op() << ". status=" << status.ToString();
|
<< node_def.op() << ". status=" << status.ToString();
|
||||||
} else {
|
} else {
|
||||||
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
|
VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
|
||||||
|
|
|
@ -352,8 +352,8 @@ void AppendTensorShapeToFingerprint(const PartialTensorShape& shape,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ShouldCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
|
Status MustCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
|
||||||
bool* compile_with_xla) {
|
bool* compile_with_xla) {
|
||||||
if (!op->is_function()) {
|
if (!op->is_function()) {
|
||||||
*compile_with_xla = false;
|
*compile_with_xla = false;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -368,7 +368,7 @@ Status ShouldCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Does node have an explicit request to compile or not?
|
// Does node have an explicit request to compile or not?
|
||||||
Status status = op->Attrs().Get(kXlaCompileAttr, compile_with_xla);
|
Status status = op->Attrs().Get(kXlaMustCompileAttr, compile_with_xla);
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
DVLOG(2) << "Caller explicitly requested "
|
DVLOG(2) << "Caller explicitly requested "
|
||||||
<< (*compile_with_xla ? "" : "not ")
|
<< (*compile_with_xla ? "" : "not ")
|
||||||
|
@ -383,7 +383,7 @@ Status ShouldCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
|
||||||
return errors::NotFound("Failed to find function '", op->Name(), "'");
|
return errors::NotFound("Failed to find function '", op->Name(), "'");
|
||||||
}
|
}
|
||||||
|
|
||||||
status = GetNodeAttr(AttrSlice(&function_def->attr()), kXlaCompileAttr,
|
status = GetNodeAttr(AttrSlice(&function_def->attr()), kXlaMustCompileAttr,
|
||||||
compile_with_xla);
|
compile_with_xla);
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
DVLOG(2) << "Function definition explicitly specifies "
|
DVLOG(2) << "Function definition explicitly specifies "
|
||||||
|
@ -511,12 +511,12 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||||
bool run_function_with_flr = false;
|
bool run_function_with_flr = false;
|
||||||
if (op->is_function()) {
|
if (op->is_function()) {
|
||||||
bool compile_with_xla;
|
bool compile_with_xla;
|
||||||
TF_RETURN_IF_ERROR(ShouldCompileWithXLA(op, ctx, &compile_with_xla));
|
TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
|
||||||
if (compile_with_xla) {
|
if (compile_with_xla) {
|
||||||
// Note that it is not ideal, but currently correct, to set this
|
// Note that it is not ideal, but currently correct, to set this
|
||||||
// attribute after computing the kernel cache key above.
|
// attribute after computing the kernel cache key above.
|
||||||
// Note: If the attribute is already set to true, this is a noop.
|
// Note: If the attribute is already set to true, this is a noop.
|
||||||
op->MutableAttrs()->Set(kXlaCompileAttr, true);
|
op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
|
||||||
} else {
|
} else {
|
||||||
run_function_with_flr = true;
|
run_function_with_flr = true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -449,7 +449,7 @@ class Function(object):
|
||||||
if self._implements is not None:
|
if self._implements is not None:
|
||||||
attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements
|
attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements
|
||||||
if self._experimental_compile is not None:
|
if self._experimental_compile is not None:
|
||||||
attributes.update(_XlaCompile=bool(self._experimental_compile))
|
attributes.update(_XlaMustCompile=bool(self._experimental_compile))
|
||||||
if not attributes:
|
if not attributes:
|
||||||
attributes = None
|
attributes = None
|
||||||
return function_lib.defun_with_attributes(
|
return function_lib.defun_with_attributes(
|
||||||
|
|
|
@ -45,6 +45,31 @@ class DefFunctionTest(test.TestCase):
|
||||||
# XLA support is not yet enabled for TF ROCm
|
# XLA support is not yet enabled for TF ROCm
|
||||||
self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
|
self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
|
||||||
|
|
||||||
|
def testDerivative(self):
|
||||||
|
if test.is_built_with_rocm():
|
||||||
|
return
|
||||||
|
|
||||||
|
def fn(x, a):
|
||||||
|
return 2 * x + a
|
||||||
|
|
||||||
|
xla_func = def_function.function(fn, experimental_compile=True)
|
||||||
|
|
||||||
|
with backprop.GradientTape() as tape:
|
||||||
|
inputs = constant_op.constant([1., 2., 2., 3., 3.])
|
||||||
|
tape.watch(inputs)
|
||||||
|
outputs = xla_func(inputs, 1)
|
||||||
|
|
||||||
|
self.assertAllClose([2, 2, 2, 2, 2], tape.gradient(outputs, inputs))
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
(forward, backward) = xla_func.get_concrete_function(
|
||||||
|
inputs, 1)._delayed_rewrite_functions.forward_backward()
|
||||||
|
|
||||||
|
# Check that the must-compile attribute gets correctly propagated to the
|
||||||
|
# created derivatives.
|
||||||
|
self.assertTrue(backward.function_def.attr['_XlaMustCompile'])
|
||||||
|
self.assertTrue(forward.definition.attr['_XlaMustCompile'])
|
||||||
|
|
||||||
def testUnsupportedOps(self):
|
def testUnsupportedOps(self):
|
||||||
|
|
||||||
def fn(x):
|
def fn(x):
|
||||||
|
|
Loading…
Reference in New Issue