Fix experimental_compile=True for graph mode

Previously the attribute was honored only in eager mode and was a no-op in
graph mode. This also fixes the case of a function with
experimental_compile=True not being compiled when called from an
experimental_compile=False context.
PiperOrigin-RevId: 286682281
Change-Id: Ifbc6efa2c82ae13f5d124ec6aaf440e1639a42c3
Author: George Karpenkov, 2019-12-20 21:21:33 -08:00 (committed by TensorFlower Gardener)
Parent: 8a7097eb9b
Commit: f3c0a3e0e4
11 changed files with 190 additions and 62 deletions
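
As context for the diffs below, a minimal sketch of the graph-mode usage this
change enables, modeled on the experimental_compile_test.py added further down
(same internal tensorflow.python modules as that test):

```python
from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


def fn(x, a):
  return x + a


with ops.Graph().as_default() as g:
  # Before this fix, experimental_compile=True was honored only on the eager
  # execution path; in a Graph context the attribute was silently dropped.
  xla_func = def_function.function(fn, experimental_compile=True)
  inputs = array_ops.placeholder(dtypes.float32, [5])
  x = xla_func(inputs, 1)
  with session.Session(graph=g) as sess:
    print(sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]}))  # [2. 3. 3. 4. 4.]
```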


@@ -21,7 +21,7 @@ namespace tensorflow {
 bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
                                        const NodeDef& node_def) const {
-  return CanCreateXlaKernel(flr, node_def);
+  return CanCreateXlaKernel(node_def);
 }

 Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,


@@ -98,12 +98,14 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
   (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
   Init({fdef});
   XlaKernelCreator xla_kernel_creator;
-  Status status = xla_kernel_creator.CreateKernel(
-      flr_, ToNodeDef(R"pb(
+  NodeDef callsite =
+      ToNodeDef(R"pb(
         name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
-      )pb"),
-      &kernel_);
+      )pb");
+  (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
+  // Note: need to set attribute on the created node.
+  Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
   ASSERT_TRUE(status.ok()) << status.ToString();
   EXPECT_EQ("XTimesY", kernel_->name());


@@ -23,7 +23,9 @@ limitations under the License.
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/util/ptr_util.h"
@@ -68,40 +70,10 @@ class SinglePassSearch {
 };

 }  // namespace

-bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
-                        const NodeDef& node_def) {
-  const FunctionDef* function_def =
-      flr.GetFunctionLibraryDefinition()->Find(node_def.name());
-  if (function_def == nullptr) {
-    // The node def is not calling a function. Individual ops can be
-    // run directly using on-demand mode, no need to create XlaLaunch
-    // kernel for them.
-    return false;
-  }
+bool CanCreateXlaKernel(const NodeDef& node_def) {
   // If kXlaMustCompileAttr is set on the node_def, use its value.
   const auto& it = node_def.attr().find(kXlaMustCompileAttr);
-  if (it != node_def.attr().end()) {
-    return it->second.b();
-  }
-
-  // kXlaMustCompileAttr is not set on node_def, check if it is set on
-  // FunctionDef.
-  bool xla_compile = false;
-  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
-      node_def, kXlaMustCompileAttr, &xla_compile);
-  if (!status.ok() || !xla_compile) {
-    if (VLOG_IS_ON(3)) {
-      if (!status.ok()) {
-        VLOG(3) << "No " << kXlaMustCompileAttr << " attr defined for "
-                << node_def.op() << ". status=" << status.ToString();
-      } else {
-        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
-      }
-    }
-    return false;
-  }
-  return true;
+  return it != node_def.attr().end() && it->second.b();
 }

 // Given a FunctionLibraryRuntime and a NodeDef calling a function in the
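
In Python terms, the new check amounts to the following (a hypothetical
rendering for illustration only, not a TF API; `node_def` stands for a
`tensorflow.NodeDef` proto):

```python
def can_create_xla_kernel(node_def):
  # True iff the call node itself carries _XlaMustCompile=true. The function
  # library is no longer consulted; instead, the Python layer now stamps the
  # attribute onto the call node (see the partitioned_call change below).
  return ("_XlaMustCompile" in node_def.attr
          and node_def.attr["_XlaMustCompile"].b)
```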
@@ -118,8 +90,11 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
   FunctionLibraryRuntime::Handle handle;
   // If node_def is not instantiable, e.g., the function does not exist,
   // simply bail out.
+  NameAttrList function;
+  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
   TF_RETURN_IF_ERROR(
-      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
+      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
   *fbody = flr->GetFunctionBody(handle);
   CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
   const DataTypeVector& arg_types = (*fbody)->arg_types;
@@ -149,7 +124,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
 Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
                        std::unique_ptr<OpKernel>* kernel) {
-  if (!CanCreateXlaKernel(*flr, node_def)) {
+  if (!CanCreateXlaKernel(node_def)) {
     return errors::Internal("Invalid node: ", node_def.ShortDebugString());
   }
@@ -241,9 +216,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
   // Create the kernel.
   NameAttrList function;
-  function.set_name(node_def.op());
-  *(function.mutable_attr()) = node_def.attr();
+  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
   Device* dev = flr->device();
   Status s;
   OpKernelConstruction construction(


@@ -24,11 +24,9 @@ namespace tensorflow {
 class FunctionLibraryRuntime;
 class OpKernel;

-// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
-// true if 'node_def' is a call to a compilable function defined in 'flr',
-// with the kXlaCompileAttr set.
-bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
-                        const NodeDef& node_def);
+// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
+// set.
+bool CanCreateXlaKernel(const NodeDef& node_def);

 // Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
 Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,


@@ -2533,6 +2533,7 @@ tf_cuda_library(
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/util:port",
         "//tensorflow/core/util:stats_calculator_portable",
+        "//tensorflow/compiler/jit:common",
     ] + if_static(
         extra_deps = ["@com_google_protobuf//:protobuf"],
         otherwise = ["@com_google_protobuf//:protobuf_headers"],


@@ -1313,10 +1313,12 @@ Status DirectSession::CreateExecutors(
       options_.config.experimental().has_session_metadata()
           ? &options_.config.experimental().session_metadata()
           : nullptr;
+  const CustomKernelCreator* custom_kernel_creator =
+      GetDefaultCustomKernelCreator();
   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
       device_mgr_.get(), options_.env, &options_.config, graph_def_version,
       func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
-      nullptr, nullptr, session_metadata));
+      nullptr, custom_kernel_creator, session_metadata));

   GraphOptimizer optimizer(optimizer_opts);
   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {


@@ -17,6 +17,8 @@ limitations under the License.
 #include <utility>

+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -97,6 +99,11 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
   inp_mtypes->clear();
   out_mtypes->clear();

+  bool has_xla_compile = [&] {
+    const auto& it = ndef.attr().find(kXlaMustCompileAttr);
+    return it != ndef.attr().end() && it->second.b();
+  }();
+
   // For functions (which have no KernelDef) and their gradients, we can only
   // best-effort derive the memory type from the data type. For now, we assume
   // int32 is always on host memory and other types are always on device memory.
@@ -104,7 +111,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
   // to derive the correct input/output memory types. We should also split
   // host-memory and non host-memory arguments into separate type lists.
   if (!status.ok() || IsFunctionCallOp(ndef.op())) {
-    if (device_type.type_string() == "TPU") {
+    if (device_type.type_string() == "TPU" || has_xla_compile) {
       // Here we assume that if tf.function() is called within
       // "with tf.device('/device:TPU:0')", the whole function will be compiled
       // and executed on TPU. This is true today, but when we implement auto
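
The user-visible effect of this memory-type change is that int32 arguments of
a must-compile function are kept on device memory, as under explicit TPU
placement; the new testBasicInt32 cases below exercise this end to end. A
condensed sketch, assuming eager execution and the same internal test modules:

```python
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes

# int32 inputs to a compiled function no longer force host-memory placement.
xla_func = def_function.function(lambda x, a: x + a, experimental_compile=True)
inputs = constant_op.constant([1, 2, 2, 3, 3], dtype=dtypes.int32)
print(xla_func(inputs, 1))  # => [2 3 3 4 4]
```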


@@ -91,3 +91,20 @@ cuda_py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+cuda_py_test(
+    name = "experimental_compile_test",
+    srcs = ["experimental_compile_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:resource_variable_ops",
+    ],
+    python_version = "PY3",
+    tags = [
+        "no_mac",
+        "no_windows",
+    ],
+    xla_enabled = True,
+)


@@ -0,0 +1,113 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ExperimentalCompileTest(test.TestCase):
+
+  def testBasic(self):
+    with ops.Graph().as_default() as g:
+
+      def fn(x, a):
+        return x + a
+
+      xla_func = def_function.function(fn, experimental_compile=True)
+      inputs = array_ops.placeholder(dtypes.float32, [5])
+      # XLA support is not yet enabled for TF ROCm
+      if not test.is_built_with_rocm():
+        x = xla_func(inputs, 1)
+        with session.Session(graph=g) as sess:
+          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+          self.assertTrue(x.graph.as_graph_def().library.function[0]
+                          .attr["_XlaMustCompile"].b)
+          self.assertAllClose([2, 3, 3, 4, 4], y)
+
+  def testDerivative(self):
+    # XLA support is not yet enabled for TF ROCm
+    if test.is_built_with_rocm():
+      return
+
+    def fn(x, a):
+      return 2 * x + a
+
+    with ops.Graph().as_default() as g:
+      xla_func = def_function.function(fn, experimental_compile=True)
+      with backprop.GradientTape() as tape:
+        inputs = array_ops.placeholder(dtypes.float32, [5])
+        tape.watch(inputs)
+        outputs = xla_func(inputs, 1)
+      grads = tape.gradient(outputs, inputs)
+
+      with session.Session(graph=g) as sess:
+        grads_tensor = sess.run(grads, feed_dict={inputs: [1, 2, 2, 3, 3]})
+        self.assertAllClose([2, 2, 2, 2, 2], grads_tensor)
+        (forward, backward) = xla_func.get_concrete_function(
+            inputs, 1)._delayed_rewrite_functions.forward_backward()
+
+        # Check that the must-compile attribute gets correctly propagated to
+        # the created derivatives.
+        self.assertTrue(forward.definition.attr["_XlaMustCompile"])
+        self.assertTrue(backward.function_def.attr["_XlaMustCompile"])
+
+  def testBasicInt32(self):
+    with ops.Graph().as_default() as g:
+
+      def fn(x, a):
+        return x + a
+
+      xla_func = def_function.function(fn, experimental_compile=True)
+      inputs = array_ops.placeholder(dtypes.int32, [5])
+      # XLA support is not yet enabled for TF ROCm
+      if not test.is_built_with_rocm():
+        x = xla_func(inputs, 1)
+        with session.Session(graph=g) as sess:
+          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+          self.assertTrue(x.graph.as_graph_def().library.function[0]
+                          .attr["_XlaMustCompile"].b)
+          self.assertAllClose([2, 3, 3, 4, 4], y)
+
+  # Checking that we crash on an unsupported operation lets us test that the
+  # XLA compiler was actually invoked.
+  def testUnsupportedOps(self):
+    with ops.Graph().as_default() as g:
+
+      def fn(x):
+        return array_ops.unique(x).y  # Unique is not supported by XLA
+
+      xla_func = def_function.function(fn, experimental_compile=True)
+      inputs = array_ops.placeholder(dtypes.float32, [5])
+      x = xla_func(inputs)
+      # XLA support is not yet enabled for TF ROCm
+      if not test.is_built_with_rocm():
+        with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                     "not compilable"):
+          with session.Session(graph=g) as sess:
+            sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+
+
+if __name__ == "__main__":
+  test.main()


@@ -20,6 +20,7 @@ from __future__ import print_function
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -45,6 +46,18 @@ class DefFunctionTest(test.TestCase):
       # XLA support is not yet enabled for TF ROCm
       self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))

+  def testBasicInt32(self):
+
+    def fn(x, a):
+      return x + a
+
+    xla_func = def_function.function(fn, experimental_compile=True)
+
+    inputs = constant_op.constant([1, 2, 2, 3, 3], dtype=dtypes.int32)
+    if not test.is_built_with_rocm():
+      # XLA support is not yet enabled for TF ROCm
+      self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
+
   def testDerivative(self):
     if test.is_built_with_rocm():
       return


@@ -1163,17 +1163,19 @@ def partitioned_call(args,
   graph = ops.get_default_graph()
   f.add_to_graph(graph)
   op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
-  op = graph.create_op(
-      op_name,
-      args,
-      tout,
-      name=op_name,
-      attrs={
-          "Tin": tin_attr,
-          "Tout": tout_attr,
-          "f": func_attr,
-          "config_proto": config_proto,
-          "executor_type": executor_type_attr,
-      })
+
+  # Propagate the attribute indicating the need to compile from function to the
+  # call itself.
+  xla_compile_attr = "_XlaMustCompile"
+  op_attrs = {
+      "Tin": tin_attr,
+      "Tout": tout_attr,
+      "f": func_attr,
+      "config_proto": config_proto,
+      "executor_type": executor_type_attr,
+  }
+  if xla_compile_attr in f.definition.attr:
+    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
+  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
   outputs = op.outputs
   return outputs if outputs else op
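
One way to observe this propagation (a sketch that reuses the graph `g` built
in the first example above; `get_attr` is the standard `tf.Operation` accessor
and raises if the attribute is absent):

```python
for op in g.get_operations():
  if op.type in ("PartitionedCall", "StatefulPartitionedCall"):
    # The call node now carries the attribute copied from the FunctionDef,
    # which is what the C++ CanCreateXlaKernel() checks.
    print(op.name, op.get_attr("_XlaMustCompile"))  # ... True
```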