diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc index 3cc68f2a1a4..61c89d8a67a 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -80,31 +80,35 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, // Make sure that kernels have been registered on the JIT device. XlaOpRegistry::RegisterCompilationKernels(); - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; - if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { - std::vector - uncompilable_node_info; - for (const auto& it : uncompilable_nodes_map) { - for (const auto& info : it.second.second) { - uncompilable_node_info.emplace_back(info); + + // Only check for compilability if the MLIR bridge is not enabled. + if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; + if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { + std::vector + uncompilable_node_info; + for (const auto& it : uncompilable_nodes_map) { + for (const auto& info : it.second.second) { + uncompilable_node_info.emplace_back(info); + } } - } - string message = absl::StrCat( - "Function invoked by the following node is not compilable: ", - SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); - absl::StrAppend(&message, "Uncompilable nodes:"); - for (const auto& node_info : uncompilable_node_info) { - string node_message = - absl::StrCat("\n", node_info.name, ": ", - node_info.uncompilable_reason, "\n", "\tStacktrace:\n"); - for (const auto& stack_frame : node_info.stack_trace) { - absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", - stack_frame.name, stack_frame.function_name); + string message = absl::StrCat( + "Function invoked by the following node is not compilable: ", + SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); + absl::StrAppend(&message, "Uncompilable nodes:"); + for (const auto& node_info : uncompilable_node_info) { + string node_message = absl::StrCat("\n", node_info.name, ": ", + node_info.uncompilable_reason, "\n", + "\tStacktrace:\n"); + for (const auto& stack_frame : node_info.stack_trace) { + absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", + stack_frame.name, stack_frame.function_name); + } + absl::StrAppend(&message, node_message); } - absl::StrAppend(&message, node_message); + VLOG(1) << message; + return errors::InvalidArgument(message); } - VLOG(1) << message; - return errors::InvalidArgument(message); } // Get function body, constant args, and resource args. diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index d9450cb6364..c2b5000647d 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -123,7 +123,6 @@ tf_xla_py_test( name = "adagrad_da_test", size = "small", srcs = ["adagrad_da_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -161,7 +160,6 @@ tf_xla_py_test( srcs = ["add_n_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -689,7 +687,6 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -929,7 +926,6 @@ tf_xla_py_test( name = "pooling_ops_test", size = "medium", srcs = ["pooling_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1243,7 +1239,6 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1304,7 +1299,6 @@ tf_xla_py_test( srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1333,7 +1327,6 @@ tf_xla_py_test( srcs = ["tensor_list_ops_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1889,7 +1882,6 @@ tf_xla_py_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], - enable_mlir_bridge = True, shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index db54f2f6563..0045a7958b4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -52,6 +53,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -726,8 +728,18 @@ Status XlaCompiler::CompileFunction( } VLOG(1) << "===================================================="; - TF_RETURN_IF_ERROR( - CompileGraph(options, function_id, std::move(graph), args, result)); + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + VLOG(1) << "Using MLIR bridge"; + GraphDebugInfo debug_info; + TF_RETURN_IF_ERROR(CompileGraphToXlaHlo( + std::move(*graph), {args.data(), args.size()}, + options_.device_type.type_string(), options.use_tuple_arg, + *options_.flib_def, debug_info, options_.shape_representation_fn, + result)); + } else { + TF_RETURN_IF_ERROR( + CompileGraph(options, function_id, std::move(graph), args, result)); + } VLOG(1) << "===================================================="; cache_[{function_id, arg_vector}] = *result;