Fix experimental_compile=True for graph mode
Previously the attribute only worked in eager mode and was a no-op otherwise. Note that this also resolves the problem of a function with experimental_compile=True not being compiled when called from an experimental_compile=False context.

PiperOrigin-RevId: 285509461
Change-Id: I3f8d5611fea5b7430feba1c58f937e121d71b75c
parent d8369591e7
commit 0505a9e2cb
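For context, this is the user-facing behavior at stake. A minimal sketch, adapted from the experimental_compile_test.py that appears at the bottom of this diff (assumes TensorFlow ~2.1, where tf.function accepts experimental_compile; tensorflow.compat.v1 is used here for the graph/session API):

    import tensorflow.compat.v1 as tf
    from tensorflow.python.eager import def_function

    with tf.Graph().as_default() as g:

      def fn(x, a):
        return x + a

      # Request XLA compilation of fn when it runs under a plain Session.
      xla_func = def_function.function(fn, experimental_compile=True)
      inputs = tf.placeholder(tf.float32, [5])
      x = xla_func(inputs, 1)
      with tf.Session(graph=g) as sess:
        # Prints [2. 3. 3. 4. 4.]; without the graph-mode handling the
        # compile request is silently ignored under a Session.
        print(sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]}))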
@@ -23,9 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/util/ptr_util.h"
 
@@ -72,42 +70,38 @@ class SinglePassSearch {
 
 bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
                         const NodeDef& node_def) {
-  VLOG(2) << "Called CanCreateXlaKernel, input: " << SummarizeNodeDef(node_def);
-  NameAttrList attr_list;
-  if (!NameAndAttrsFromFunctionCall(node_def, &attr_list).ok()) {
-    return false;
-  }
-  std::string func_name = attr_list.name();
   const FunctionDef* function_def =
-      flr.GetFunctionLibraryDefinition()->Find(func_name);
+      flr.GetFunctionLibraryDefinition()->Find(node_def.name());
   if (function_def == nullptr) {
     // The node def is not calling a function. Individual ops can be
     // run directly using on-demand mode, no need to create XlaLaunch
     // kernel for them.
-    VLOG(2) << "Not creating XlaLaunch kernel for " << func_name
-            << " because it does not seem to be a function";
     return false;
   }
 
   // If kXlaCompileAttr is set on the node_def, use its value.
   const auto& it = node_def.attr().find(kXlaCompileAttr);
   if (it != node_def.attr().end()) {
-    bool value = it->second.b();
-    VLOG(2) << "Found " << kXlaCompileAttr
-            << " attribute with value = " << value
-            << " on node: " << SummarizeNodeDef(node_def);
-    return value;
+    return it->second.b();
   }
 
-  // Otherwise, look for it on the custom defition.
-  const auto& fit = function_def->attr().find(kXlaCompileAttr);
-  if (fit != function_def->attr().end()) {
-    bool value = fit->second.b();
-    VLOG(2) << "Found " << kXlaCompileAttr << " attribute on function "
-            << func_name << " with value = " << value;
-    return value;
+  // kXlaCompileAttr is not set on node_def, check if it is set on
+  // FunctionDef.
+  bool xla_compile = false;
+  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
+      node_def, kXlaCompileAttr, &xla_compile);
+  if (!status.ok() || !xla_compile) {
+    if (VLOG_IS_ON(3)) {
+      if (!status.ok()) {
+        VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
+                << node_def.op() << ". status=" << status.ToString();
+      } else {
+        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
+      }
+    }
+    return false;
   }
-  return false;
+  return true;
 }
 
 // Given a FunctionLibraryRuntime and a NodeDef calling a function in the
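Both branches above key off kXlaCompileAttr, i.e. the "_XlaCompile" attribute, checking the call node first and the callee's FunctionDef second. From Python the same attribute can be observed on the generated FunctionDef, mirroring the assertion in the test at the bottom of this diff (a sketch under the same TF ~2.1 assumptions, not a stable API):

    import tensorflow.compat.v1 as tf
    from tensorflow.python.eager import def_function

    with tf.Graph().as_default() as g:
      fn = def_function.function(lambda x: x + 1.0, experimental_compile=True)
      out = fn(tf.placeholder(tf.float32, [5]))

    # experimental_compile=True is recorded as "_XlaCompile" on the FunctionDef
    # in the graph's function library.
    fdef = g.as_graph_def().library.function[0]
    print(fdef.signature.name, fdef.attr["_XlaCompile"].b)  # <generated name> True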
@@ -124,11 +118,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
   FunctionLibraryRuntime::Handle handle;
   // If node_def is not instantiable, e.g., the function does not exist,
   // simply bail out.
-  NameAttrList function;
-  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
-
   TF_RETURN_IF_ERROR(
-      flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
+      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
   *fbody = flr->GetFunctionBody(handle);
   CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
   const DataTypeVector& arg_types = (*fbody)->arg_types;
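The two sides of this hunk differ in how the callee is resolved: NameAndAttrsFromFunctionCall reads the function name out of the call node's attributes (for example the "f" attr of a PartitionedCall), while the node_def.op() path assumes the node's op name is itself the function name. The distinction matters in graph mode, where tf.function calls are emitted as PartitionedCall or StatefulPartitionedCall nodes. A sketch of inspecting that from Python (same assumptions as above):

    import tensorflow.compat.v1 as tf
    from tensorflow.python.eager import def_function

    with tf.Graph().as_default() as g:
      fn = def_function.function(lambda x: x + 1.0, experimental_compile=True)
      out = fn(tf.placeholder(tf.float32, [2]))

    for node in g.as_graph_def().node:
      if "PartitionedCall" in node.op:
        # The callee's real name lives in the "f" attr, not in node.op.
        print(node.op, "->", node.attr["f"].func.name)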
@@ -250,7 +241,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
   // Create the kernel.
   NameAttrList function;
-  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
+  function.set_name(node_def.op());
+  *(function.mutable_attr()) = node_def.attr();
+
   Device* dev = flr->device();
   Status s;
   OpKernelConstruction construction(
@@ -1302,12 +1302,10 @@ Status DirectSession::CreateExecutors(
       options_.config.experimental().has_session_metadata()
           ? &options_.config.experimental().session_metadata()
           : nullptr;
-  const CustomKernelCreator* custom_kernel_creator =
-      GetDefaultCustomKernelCreator();
   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
       device_mgr_.get(), options_.env, &options_.config, graph_def_version,
       func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
-      nullptr, custom_kernel_creator, session_metadata));
+      nullptr, nullptr, session_metadata));
 
   GraphOptimizer optimizer(optimizer_opts);
   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
@@ -91,20 +91,3 @@ cuda_py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
-
-cuda_py_test(
-    name = "experimental_compile_test",
-    srcs = ["experimental_compile_test.py"],
-    additional_deps = [
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:resource_variable_ops",
-    ],
-    python_version = "PY3",
-    tags = [
-        "no_mac",
-        "no_windows",
-    ],
-    xla_enabled = True,
-)
@@ -1,68 +0,0 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.client import session
-from tensorflow.python.eager import def_function
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class ExperimentalCompileTest(test.TestCase):
-
-  def testBasic(self):
-    with ops.Graph().as_default() as g:
-
-      def fn(x, a):
-        return x + a
-
-      xla_func = def_function.function(fn, experimental_compile=True)
-      inputs = array_ops.placeholder(dtypes.float32, [5])
-      # XLA support is not yet enabled for TF ROCm
-      if not test.is_built_with_rocm():
-        x = xla_func(inputs, 1)
-        with session.Session(graph=g) as sess:
-          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
-          self.assertTrue(
-              x.graph.as_graph_def().library.function[0].attr["_XlaCompile"].b)
-          self.assertAllClose([2, 3, 3, 4, 4], y)
-
-  # Checking that we crash on an unsupported operation lets us test that the XLA
-  # compiler was actually invoked.
-  def testUnsupportedOps(self):
-    with ops.Graph().as_default() as g:
-
-      def fn(x):
-        return array_ops.unique(x).y  # Unique is not supported by XLA
-
-      xla_func = def_function.function(fn, experimental_compile=True)
-      inputs = array_ops.placeholder(dtypes.float32, [5])
-      x = xla_func(inputs)
-      # XLA support is not yet enabled for TF ROCm
-      if not test.is_built_with_rocm():
-        with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                     "not compilable"):
-          with session.Session(graph=g) as sess:
-            sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
-
-
-if __name__ == "__main__":
-  test.main()
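The commit message also calls out the nested case: a function with experimental_compile=True invoked from an experimental_compile=False context. In user code that scenario looks roughly like this (an illustrative sketch against the TF ~2.1 API; the function names are hypothetical):

    import tensorflow as tf

    @tf.function(experimental_compile=True)
    def inner(x):
      return x * 2.0 + 1.0

    @tf.function  # no compilation requested on the outer function
    def outer(x):
      # inner should still be XLA-compiled even though outer is not.
      return inner(x) - 1.0

    print(outer(tf.constant([1.0, 2.0])))  # tf.Tensor([2. 4.], ...)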