[TF2XLA] Enable using MLIR bridge when TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge is on for tf.function(compile=True)

PiperOrigin-RevId: 323707882
Change-Id: I34a513fad8a5119b8a68180fc7277ff80fc6a555
Authored by A. Unique TensorFlower on 2020-07-28 20:09:13 -07:00; committed by TensorFlower Gardener
parent 37c793e9b8
commit 4353b9cd4d
3 changed files with 33 additions and 41 deletions
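For context, a minimal usage sketch of the workflow this commit targets (not part of the change itself): the TF_XLA_FLAGS value comes from the commit title, `experimental_compile` is the TF 2.3-era Python spelling of the `compile=True` argument the title abbreviates, and the script name is illustrative.

    # Run with the MLIR bridge enabled:
    #   TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge python axpy.py
    import tensorflow as tf

    @tf.function(experimental_compile=True)  # request XLA compilation
    def axpy(a, x, y):
      return a * x + y

    print(axpy(tf.constant(2.0),
               tf.constant([1.0, 2.0]),
               tf.constant([3.0, 4.0])))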

tensorflow/compiler/jit/xla_kernel_creator_util.cc

@@ -80,9 +80,6 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
   // Make sure that kernels have been registered on the JIT device.
   XlaOpRegistry::RegisterCompilationKernels();
-  // Only check for compilability if the MLIR bridge is not enabled.
-  if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
   RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
   if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
     std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
@@ -97,9 +94,9 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
         SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
     absl::StrAppend(&message, "Uncompilable nodes:");
     for (const auto& node_info : uncompilable_node_info) {
-      string node_message = absl::StrCat("\n", node_info.name, ": ",
-                                         node_info.uncompilable_reason, "\n",
-                                         "\tStacktrace:\n");
+      string node_message =
+          absl::StrCat("\n", node_info.name, ": ",
+                       node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
       for (const auto& stack_frame : node_info.stack_trace) {
         absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
                               stack_frame.name, stack_frame.function_name);
@@ -109,7 +106,6 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
     VLOG(1) << message;
     return errors::InvalidArgument(message);
   }
-  }

   // Get function body, constant args, and resource args.
   const FunctionBody* fbody = nullptr;
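With the guard above removed, the compilability check runs whether or not the bridge flag is set. Purely as an illustration, and under two assumptions not stated in the commit (that `tf.py_function` has no XLA kernel and that the failure surfaces to Python as InvalidArgumentError), here is a function that would trigger the "Uncompilable nodes" message assembled in this hunk:

    import tensorflow as tf

    def host_only(x):
      return x.numpy() + 1.0  # opaque Python callback; XLA cannot compile it

    @tf.function(experimental_compile=True)
    def f(x):
      return tf.py_function(host_only, [x], tf.float32)

    try:
      f(tf.constant(1.0))
    except tf.errors.InvalidArgumentError as e:
      print(e)  # expected to include the "Uncompilable nodes:" summary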

tensorflow/compiler/tests/BUILD

@@ -123,6 +123,7 @@ tf_xla_py_test(
     name = "adagrad_da_test",
     size = "small",
     srcs = ["adagrad_da_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -160,6 +161,7 @@ tf_xla_py_test(
     srcs = ["add_n_test.py"],
     # TensorList ops are not implemented in the on-demand compilation model yet.
     disabled_backends = ["cpu_ondemand"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -687,6 +689,7 @@ tf_xla_py_test(
     name = "fft_test",
     size = "medium",
     srcs = ["fft_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     shard_count = 6,
     tags = [
@@ -926,6 +929,7 @@ tf_xla_py_test(
     name = "pooling_ops_test",
     size = "medium",
     srcs = ["pooling_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     shard_count = 20,
     tags = [
@@ -1239,6 +1243,7 @@ tf_xla_py_test(
     name = "stack_ops_test",
     size = "small",
     srcs = ["stack_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "config-cuda-only",
@@ -1299,6 +1304,7 @@ tf_xla_py_test(
     srcs = ["tensor_array_ops_test.py"],
     # TensorArray ops are not implemented in the on-demand compilation model yet.
     disabled_backends = ["cpu_ondemand"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "config-cuda-only",
@@ -1327,6 +1333,7 @@ tf_xla_py_test(
     srcs = ["tensor_list_ops_test.py"],
     # TensorList ops are not implemented in the on-demand compilation model yet.
     disabled_backends = ["cpu_ondemand"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -1882,6 +1889,7 @@ tf_xla_py_test(
     name = "special_math_test",
     size = "medium",
     srcs = ["special_math_test.py"],
+    enable_mlir_bridge = True,
     shard_count = 5,
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip

tensorflow/compiler/tf2xla/xla_compiler.cc

@@ -23,7 +23,6 @@ limitations under the License.
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -53,7 +52,6 @@ limitations under the License.
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/core/util/dump_graph.h"

 namespace tensorflow {
@@ -728,18 +726,8 @@ Status XlaCompiler::CompileFunction(
   }
   VLOG(1) << "====================================================";
-  if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
-    VLOG(1) << "Using MLIR bridge";
-    GraphDebugInfo debug_info;
-    TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
-        std::move(*graph), {args.data(), args.size()},
-        options_.device_type.type_string(), options.use_tuple_arg,
-        *options_.flib_def, debug_info, options_.shape_representation_fn,
-        result));
-  } else {
-    TF_RETURN_IF_ERROR(
-        CompileGraph(options, function_id, std::move(graph), args, result));
-  }
+  TF_RETURN_IF_ERROR(
+      CompileGraph(options, function_id, std::move(graph), args, result));
   VLOG(1) << "====================================================";
   cache_[{function_id, arg_vector}] = *result;