diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc index 61c89d8a67a..3cc68f2a1a4 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -80,35 +80,31 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, // Make sure that kernels have been registered on the JIT device. XlaOpRegistry::RegisterCompilationKernels(); - - // Only check for compilability if the MLIR bridge is not enabled. - if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; - if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { - std::vector - uncompilable_node_info; - for (const auto& it : uncompilable_nodes_map) { - for (const auto& info : it.second.second) { - uncompilable_node_info.emplace_back(info); - } + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; + if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { + std::vector + uncompilable_node_info; + for (const auto& it : uncompilable_nodes_map) { + for (const auto& info : it.second.second) { + uncompilable_node_info.emplace_back(info); } - string message = absl::StrCat( - "Function invoked by the following node is not compilable: ", - SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); - absl::StrAppend(&message, "Uncompilable nodes:"); - for (const auto& node_info : uncompilable_node_info) { - string node_message = absl::StrCat("\n", node_info.name, ": ", - node_info.uncompilable_reason, "\n", - "\tStacktrace:\n"); - for (const auto& stack_frame : node_info.stack_trace) { - absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", - stack_frame.name, stack_frame.function_name); - } - absl::StrAppend(&message, node_message); - } - VLOG(1) << message; - return errors::InvalidArgument(message); } + string message = absl::StrCat( + "Function invoked by the following node is not compilable: ", + SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); + absl::StrAppend(&message, "Uncompilable nodes:"); + for (const auto& node_info : uncompilable_node_info) { + string node_message = + absl::StrCat("\n", node_info.name, ": ", + node_info.uncompilable_reason, "\n", "\tStacktrace:\n"); + for (const auto& stack_frame : node_info.stack_trace) { + absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", + stack_frame.name, stack_frame.function_name); + } + absl::StrAppend(&message, node_message); + } + VLOG(1) << message; + return errors::InvalidArgument(message); } // Get function body, constant args, and resource args. diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index c2b5000647d..d9450cb6364 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -123,6 +123,7 @@ tf_xla_py_test( name = "adagrad_da_test", size = "small", srcs = ["adagrad_da_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -160,6 +161,7 @@ tf_xla_py_test( srcs = ["add_n_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -687,6 +689,7 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -926,6 +929,7 @@ tf_xla_py_test( name = "pooling_ops_test", size = "medium", srcs = ["pooling_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1239,6 +1243,7 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1299,6 +1304,7 @@ tf_xla_py_test( srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1327,6 +1333,7 @@ tf_xla_py_test( srcs = ["tensor_list_ops_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1882,6 +1889,7 @@ tf_xla_py_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], + enable_mlir_bridge = True, shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0045a7958b4..db54f2f6563 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -53,7 +52,6 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -728,18 +726,8 @@ Status XlaCompiler::CompileFunction( } VLOG(1) << "===================================================="; - if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { - VLOG(1) << "Using MLIR bridge"; - GraphDebugInfo debug_info; - TF_RETURN_IF_ERROR(CompileGraphToXlaHlo( - std::move(*graph), {args.data(), args.size()}, - options_.device_type.type_string(), options.use_tuple_arg, - *options_.flib_def, debug_info, options_.shape_representation_fn, - result)); - } else { - TF_RETURN_IF_ERROR( - CompileGraph(options, function_id, std::move(graph), args, result)); - } + TF_RETURN_IF_ERROR( + CompileGraph(options, function_id, std::move(graph), args, result)); VLOG(1) << "===================================================="; cache_[{function_id, arg_vector}] = *result;