[TF2XLA] Enable using MLIR bridge when TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge is on for tf.function(compile=True)
PiperOrigin-RevId: 323683301 Change-Id: Ib1cfaec1bd27c3bf691820c616cdca1721aabe25
This commit is contained in:
parent
9dc7dc2468
commit
42a9b7f7ae
@ -80,31 +80,35 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
|||||||
|
|
||||||
// Make sure that kernels have been registered on the JIT device.
|
// Make sure that kernels have been registered on the JIT device.
|
||||||
XlaOpRegistry::RegisterCompilationKernels();
|
XlaOpRegistry::RegisterCompilationKernels();
|
||||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
|
|
||||||
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
|
// Only check for compilability if the MLIR bridge is not enabled.
|
||||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
|
||||||
uncompilable_node_info;
|
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
|
||||||
for (const auto& it : uncompilable_nodes_map) {
|
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
|
||||||
for (const auto& info : it.second.second) {
|
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||||
uncompilable_node_info.emplace_back(info);
|
uncompilable_node_info;
|
||||||
|
for (const auto& it : uncompilable_nodes_map) {
|
||||||
|
for (const auto& info : it.second.second) {
|
||||||
|
uncompilable_node_info.emplace_back(info);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
string message = absl::StrCat(
|
||||||
string message = absl::StrCat(
|
"Function invoked by the following node is not compilable: ",
|
||||||
"Function invoked by the following node is not compilable: ",
|
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
|
||||||
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
|
absl::StrAppend(&message, "Uncompilable nodes:");
|
||||||
absl::StrAppend(&message, "Uncompilable nodes:");
|
for (const auto& node_info : uncompilable_node_info) {
|
||||||
for (const auto& node_info : uncompilable_node_info) {
|
string node_message = absl::StrCat("\n", node_info.name, ": ",
|
||||||
string node_message =
|
node_info.uncompilable_reason, "\n",
|
||||||
absl::StrCat("\n", node_info.name, ": ",
|
"\tStacktrace:\n");
|
||||||
node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
|
for (const auto& stack_frame : node_info.stack_trace) {
|
||||||
for (const auto& stack_frame : node_info.stack_trace) {
|
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
|
||||||
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
|
stack_frame.name, stack_frame.function_name);
|
||||||
stack_frame.name, stack_frame.function_name);
|
}
|
||||||
|
absl::StrAppend(&message, node_message);
|
||||||
}
|
}
|
||||||
absl::StrAppend(&message, node_message);
|
VLOG(1) << message;
|
||||||
|
return errors::InvalidArgument(message);
|
||||||
}
|
}
|
||||||
VLOG(1) << message;
|
|
||||||
return errors::InvalidArgument(message);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get function body, constant args, and resource args.
|
// Get function body, constant args, and resource args.
|
||||||
|
@ -123,7 +123,6 @@ tf_xla_py_test(
|
|||||||
name = "adagrad_da_test",
|
name = "adagrad_da_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["adagrad_da_test.py"],
|
srcs = ["adagrad_da_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
@ -161,7 +160,6 @@ tf_xla_py_test(
|
|||||||
srcs = ["add_n_test.py"],
|
srcs = ["add_n_test.py"],
|
||||||
# TensorList ops are not implemented in the on-demand compilation model yet.
|
# TensorList ops are not implemented in the on-demand compilation model yet.
|
||||||
disabled_backends = ["cpu_ondemand"],
|
disabled_backends = ["cpu_ondemand"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
@ -689,7 +687,6 @@ tf_xla_py_test(
|
|||||||
name = "fft_test",
|
name = "fft_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["fft_test.py"],
|
srcs = ["fft_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 6,
|
shard_count = 6,
|
||||||
tags = [
|
tags = [
|
||||||
@ -929,7 +926,6 @@ tf_xla_py_test(
|
|||||||
name = "pooling_ops_test",
|
name = "pooling_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["pooling_ops_test.py"],
|
srcs = ["pooling_ops_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
tags = [
|
tags = [
|
||||||
@ -1243,7 +1239,6 @@ tf_xla_py_test(
|
|||||||
name = "stack_ops_test",
|
name = "stack_ops_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["stack_ops_test.py"],
|
srcs = ["stack_ops_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"config-cuda-only",
|
"config-cuda-only",
|
||||||
@ -1304,7 +1299,6 @@ tf_xla_py_test(
|
|||||||
srcs = ["tensor_array_ops_test.py"],
|
srcs = ["tensor_array_ops_test.py"],
|
||||||
# TensorArray ops are not implemented in the on-demand compilation model yet.
|
# TensorArray ops are not implemented in the on-demand compilation model yet.
|
||||||
disabled_backends = ["cpu_ondemand"],
|
disabled_backends = ["cpu_ondemand"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"config-cuda-only",
|
"config-cuda-only",
|
||||||
@ -1333,7 +1327,6 @@ tf_xla_py_test(
|
|||||||
srcs = ["tensor_list_ops_test.py"],
|
srcs = ["tensor_list_ops_test.py"],
|
||||||
# TensorList ops are not implemented in the on-demand compilation model yet.
|
# TensorList ops are not implemented in the on-demand compilation model yet.
|
||||||
disabled_backends = ["cpu_ondemand"],
|
disabled_backends = ["cpu_ondemand"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
@ -1889,7 +1882,6 @@ tf_xla_py_test(
|
|||||||
name = "special_math_test",
|
name = "special_math_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["special_math_test.py"],
|
srcs = ["special_math_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
shard_count = 5,
|
shard_count = 5,
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/defs.h"
|
#include "tensorflow/compiler/jit/defs.h"
|
||||||
#include "tensorflow/compiler/jit/flags.h"
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
|
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
|
||||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
|
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
@ -52,6 +53,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||||
#include "tensorflow/core/util/dump_graph.h"
|
#include "tensorflow/core/util/dump_graph.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -726,8 +728,18 @@ Status XlaCompiler::CompileFunction(
|
|||||||
}
|
}
|
||||||
|
|
||||||
VLOG(1) << "====================================================";
|
VLOG(1) << "====================================================";
|
||||||
TF_RETURN_IF_ERROR(
|
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
|
||||||
CompileGraph(options, function_id, std::move(graph), args, result));
|
VLOG(1) << "Using MLIR bridge";
|
||||||
|
GraphDebugInfo debug_info;
|
||||||
|
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
|
||||||
|
std::move(*graph), {args.data(), args.size()},
|
||||||
|
options_.device_type.type_string(), options.use_tuple_arg,
|
||||||
|
*options_.flib_def, debug_info, options_.shape_representation_fn,
|
||||||
|
result));
|
||||||
|
} else {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CompileGraph(options, function_id, std::move(graph), args, result));
|
||||||
|
}
|
||||||
VLOG(1) << "====================================================";
|
VLOG(1) << "====================================================";
|
||||||
|
|
||||||
cache_[{function_id, arg_vector}] = *result;
|
cache_[{function_id, arg_vector}] = *result;
|
||||||
|
Loading…
Reference in New Issue
Block a user