Internal change
PiperOrigin-RevId: 299186503 Change-Id: Ic39e151185e5e98636b2b71a9812051c0c357f5c
This commit is contained in:
parent
4cedd7e86d
commit
244a9b0a41
@ -64,6 +64,7 @@ cc_library(
|
|||||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||||
"@llvm-project//llvm:target",
|
"@llvm-project//llvm:target",
|
||||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||||
|
"//tensorflow/core:regexp_internal",
|
||||||
] + if_llvm_aarch64_available([
|
] + if_llvm_aarch64_available([
|
||||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||||
]),
|
]),
|
||||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/regexp.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -105,8 +106,9 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
|||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
xla::XlaComputation computation;
|
xla::XlaComputation computation;
|
||||||
if (flags.mlir_components == "Bridge") {
|
if (flags.mlir_components == "Bridge") {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
|
||||||
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
|
graph_def, config, &computation, flags.debug_info,
|
||||||
|
flags.debug_info_path_begin_marker));
|
||||||
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
|
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
|
||||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
||||||
client, &computation));
|
client, &computation));
|
||||||
@ -166,6 +168,23 @@ static void InitializeTargets() {
|
|||||||
LLVMInitializeX86AsmPrinter();
|
LLVMInitializeX86AsmPrinter();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replaces {{tag.type tag.name}} in the error message with tag_name.
|
||||||
|
// TODO(bixia): We currently only handlge tag.type == "node".
|
||||||
|
//
|
||||||
|
// In the error message, a graph node is represented as {{tag.type, tag.name}},
|
||||||
|
// to allow a Python debugger to insert source information about the graph node.
|
||||||
|
// For example, a Python add expression may be represented as
|
||||||
|
// {{node, x_y_sum}} = Add(x, y) in the error message. See routine interpolate
|
||||||
|
// in tensorflow/python/framework/error_interpolation.py for more detail.
|
||||||
|
static std::string InterpolateErrorMessage(std::string message) {
|
||||||
|
// See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
|
||||||
|
// Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
|
||||||
|
static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
|
||||||
|
RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
|
||||||
|
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
Status Main(const MainFlags& flags) {
|
Status Main(const MainFlags& flags) {
|
||||||
absl::call_once(targets_init, &InitializeTargets);
|
absl::call_once(targets_init, &InitializeTargets);
|
||||||
|
|
||||||
@ -192,8 +211,13 @@ Status Main(const MainFlags& flags) {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
||||||
CompileResult compile_result;
|
CompileResult compile_result;
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
CompileGraph(std::move(graph_def), config, flags, &compile_result));
|
Status status =
|
||||||
|
CompileGraph(std::move(graph_def), config, flags, &compile_result);
|
||||||
|
if (!status.ok()) {
|
||||||
|
return Status(status.code(),
|
||||||
|
InterpolateErrorMessage(status.error_message()));
|
||||||
|
}
|
||||||
|
|
||||||
// Write output files.
|
// Write output files.
|
||||||
Env* env = Env::Default();
|
Env* env = Env::Default();
|
||||||
|
@ -24,6 +24,13 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
|||||||
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
|
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
|
||||||
"be in the human-readable proto text format, otherwise it is expected "
|
"be in the human-readable proto text format, otherwise it is expected "
|
||||||
"to be in the proto binary format."},
|
"to be in the proto binary format."},
|
||||||
|
{"debug_info", &flags->debug_info,
|
||||||
|
"Graph debug info file. If the file ends in '.pbtxt' it is expected to "
|
||||||
|
"be in the human-readable proto text format, otherwise it is expected "
|
||||||
|
"to be in the proto binary format."},
|
||||||
|
{"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker,
|
||||||
|
"If not none, only keep the file path in the debug information after the"
|
||||||
|
" marker. The default value is empty"},
|
||||||
{"config", &flags->config,
|
{"config", &flags->config,
|
||||||
"Input file containing Config proto. If the file ends in '.pbtxt' it "
|
"Input file containing Config proto. If the file ends in '.pbtxt' it "
|
||||||
"is expected to be in the human-readable proto text format, otherwise "
|
"is expected to be in the human-readable proto text format, otherwise "
|
||||||
|
@ -28,6 +28,8 @@ namespace tfcompile {
|
|||||||
|
|
||||||
struct MainFlags {
|
struct MainFlags {
|
||||||
string graph;
|
string graph;
|
||||||
|
string debug_info;
|
||||||
|
string debug_info_path_begin_marker;
|
||||||
string config;
|
string config;
|
||||||
bool dump_fetch_nodes = false;
|
bool dump_fetch_nodes = false;
|
||||||
string target_triple;
|
string target_triple;
|
||||||
|
@ -1,11 +1,34 @@
|
|||||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||||
|
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//visibility:private"],
|
default_visibility = ["//visibility:private"],
|
||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
glob_lit_tests(
|
||||||
|
data = [":filecheck_test_utilities"],
|
||||||
|
driver = "@llvm-project//mlir:run_lit.sh",
|
||||||
|
test_file_exts = ["lit.pbtxt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bundle together all of the test utilities that are used by tests.
|
||||||
|
filegroup(
|
||||||
|
name = "filecheck_test_utilities",
|
||||||
|
testonly = True,
|
||||||
|
srcs = [
|
||||||
|
"test_error_message.lit.pbtxt.config.pbtxt",
|
||||||
|
"test_error_message.lit.pbtxt.debug.pbtxt",
|
||||||
|
"test_error_message.lit.pbtxt.fake_py.debug",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
"//tensorflow/compiler/aot:tfcompile",
|
||||||
|
"@llvm-project//llvm:FileCheck",
|
||||||
|
"@llvm-project//llvm:not",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# We disable some tfcompile tests in the open source build with the
|
# We disable some tfcompile tests in the open source build with the
|
||||||
# "manual" tag to avoid making our OSS users build LLVM twice
|
# "manual" tag to avoid making our OSS users build LLVM twice
|
||||||
# (once for host and once for target).
|
# (once for host and once for target).
|
||||||
@ -60,6 +83,7 @@ genrule(
|
|||||||
testonly = 1,
|
testonly = 1,
|
||||||
outs = [
|
outs = [
|
||||||
"test_graph_tfadd.pb",
|
"test_graph_tfadd.pb",
|
||||||
|
"test_debuginfo_tfadd.pb",
|
||||||
"test_graph_tfadd_with_ckpt.ckpt",
|
"test_graph_tfadd_with_ckpt.ckpt",
|
||||||
"test_graph_tfadd_with_ckpt.pb",
|
"test_graph_tfadd_with_ckpt.pb",
|
||||||
"test_graph_tfadd_with_ckpt_saver.ckpt",
|
"test_graph_tfadd_with_ckpt_saver.ckpt",
|
||||||
@ -317,6 +341,7 @@ tf_library(
|
|||||||
testonly = 1,
|
testonly = 1,
|
||||||
config = "test_graph_tfadd.config.pbtxt",
|
config = "test_graph_tfadd.config.pbtxt",
|
||||||
cpp_class = "AddComp",
|
cpp_class = "AddComp",
|
||||||
|
debug_info = "test_debuginfo_tfadd.pb",
|
||||||
graph = "test_graph_tfadd.pb",
|
graph = "test_graph_tfadd.pb",
|
||||||
include_standard_runtime_deps = False,
|
include_standard_runtime_deps = False,
|
||||||
mlir_components = "Bridge",
|
mlir_components = "Bridge",
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2
|
|||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import error_interpolation
|
||||||
from tensorflow.python.framework import function
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -184,7 +185,22 @@ def tfvariable_sequential_updates(_):
|
|||||||
array_ops.identity(updates, name='result')
|
array_ops.identity(updates, name='result')
|
||||||
|
|
||||||
|
|
||||||
def write_graph(build_graph, out_dir):
|
def export_debug_info(exported_graph):
|
||||||
|
"""Exports debug information from a graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exported_graph: A Graph that has been created by tracing a saveable view.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Corresponding GraphDebugInfo with traces for all ops in exported_graph.
|
||||||
|
"""
|
||||||
|
exported_operations = []
|
||||||
|
for op in exported_graph.get_operations():
|
||||||
|
exported_operations.append(('', op))
|
||||||
|
return error_interpolation.create_graph_debug_info_def(exported_operations)
|
||||||
|
|
||||||
|
|
||||||
|
def write_graph(build_graph, out_dir, debug_info=False):
|
||||||
"""Build a graph using build_graph and write it out."""
|
"""Build a graph using build_graph and write it out."""
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.as_default():
|
with g.as_default():
|
||||||
@ -193,10 +209,19 @@ def write_graph(build_graph, out_dir):
|
|||||||
with open(filename, 'wb') as f:
|
with open(filename, 'wb') as f:
|
||||||
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
|
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
|
||||||
|
|
||||||
|
if debug_info:
|
||||||
|
filename_debuginfo = os.path.join(
|
||||||
|
out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__)
|
||||||
|
test_debuginfo = export_debug_info(g)
|
||||||
|
with open(filename_debuginfo, 'wb') as f:
|
||||||
|
f.write(
|
||||||
|
six.ensure_binary(
|
||||||
|
test_debuginfo.SerializeToString(deterministic=True)))
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
control_flow_util.enable_control_flow_v2()
|
control_flow_util.enable_control_flow_v2()
|
||||||
write_graph(tfadd, FLAGS.out_dir)
|
write_graph(tfadd, FLAGS.out_dir, debug_info=True)
|
||||||
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
||||||
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
||||||
write_graph(tfassert_eq, FLAGS.out_dir)
|
write_graph(tfassert_eq, FLAGS.out_dir)
|
||||||
|
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
|
||||||
|
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
|
||||||
|
|
||||||
|
# Checks the error message produced by tfcompile with mlir_component
|
||||||
|
# Checks that source debug information is used in the output error message and
|
||||||
|
# the node x_y_sum = Add
|
||||||
|
# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
|
||||||
|
# CHECK: math_ops.add(x, y, name='x_y_sum')
|
||||||
|
# CHECK: build_graph(out_dir)
|
||||||
|
|
||||||
|
# Checks the error message produced by tfcompile without mlir_component
|
||||||
|
# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3]
|
||||||
|
# OLD: x_y_sum
|
||||||
|
|
||||||
|
node: {
|
||||||
|
name: "x"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr: {
|
||||||
|
key: "shape"
|
||||||
|
value: {
|
||||||
|
shape: {
|
||||||
|
dim: {
|
||||||
|
size: -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr: {
|
||||||
|
key: "dtype"
|
||||||
|
value: {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node: {
|
||||||
|
name: "y"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr: {
|
||||||
|
key: "shape"
|
||||||
|
value: {
|
||||||
|
shape: {
|
||||||
|
dim: {
|
||||||
|
size: -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr: {
|
||||||
|
key: "dtype"
|
||||||
|
value: {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node: {
|
||||||
|
name: "x_y_sum"
|
||||||
|
op: "Add"
|
||||||
|
input: "x"
|
||||||
|
input: "y"
|
||||||
|
attr: {
|
||||||
|
key: "T"
|
||||||
|
value: {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
versions: {
|
||||||
|
producer: 321
|
||||||
|
}
|
@ -0,0 +1,16 @@
|
|||||||
|
# Text form of tensorflow.tf2xla.Config proto.
|
||||||
|
feed {
|
||||||
|
id { node_name: "x" }
|
||||||
|
shape {
|
||||||
|
dim { size: 2 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
feed {
|
||||||
|
id { node_name: "y" }
|
||||||
|
shape {
|
||||||
|
dim { size: 3 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fetch {
|
||||||
|
id { node_name: "x_y_sum" }
|
||||||
|
}
|
@ -0,0 +1,28 @@
|
|||||||
|
files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug"
|
||||||
|
traces: {
|
||||||
|
key: "x@"
|
||||||
|
value: {
|
||||||
|
file_line_cols: {
|
||||||
|
line: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
traces: {
|
||||||
|
key: "x_y_sum@"
|
||||||
|
value: {
|
||||||
|
file_line_cols: {
|
||||||
|
line: 3
|
||||||
|
}
|
||||||
|
file_line_cols: {
|
||||||
|
line: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
traces: {
|
||||||
|
key: "y@"
|
||||||
|
value: {
|
||||||
|
file_line_cols: {
|
||||||
|
line: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
x = value
|
||||||
|
y = value
|
||||||
|
math_ops.add(x, y, name='x_y_sum')
|
||||||
|
build_graph(out_dir)
|
@ -26,6 +26,7 @@ def tf_library(
|
|||||||
name,
|
name,
|
||||||
graph,
|
graph,
|
||||||
config,
|
config,
|
||||||
|
debug_info = None,
|
||||||
freeze_checkpoint = None,
|
freeze_checkpoint = None,
|
||||||
freeze_saver = None,
|
freeze_saver = None,
|
||||||
cpp_class = None,
|
cpp_class = None,
|
||||||
@ -191,12 +192,15 @@ def tf_library(
|
|||||||
|
|
||||||
mlir_flag = "--mlir_components=" + mlir_components
|
mlir_flag = "--mlir_components=" + mlir_components
|
||||||
|
|
||||||
|
srcs = [tfcompile_graph, config]
|
||||||
|
debug_info_flag = ""
|
||||||
|
if debug_info:
|
||||||
|
srcs.append(debug_info)
|
||||||
|
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||||
|
|
||||||
native.genrule(
|
native.genrule(
|
||||||
name = ("gen_" + name),
|
name = ("gen_" + name),
|
||||||
srcs = [
|
srcs = srcs,
|
||||||
tfcompile_graph,
|
|
||||||
config,
|
|
||||||
],
|
|
||||||
outs = [
|
outs = [
|
||||||
header_file,
|
header_file,
|
||||||
metadata_object_file,
|
metadata_object_file,
|
||||||
@ -206,6 +210,7 @@ def tf_library(
|
|||||||
"CUDA_VISIBLE_DEVICES='' " +
|
"CUDA_VISIBLE_DEVICES='' " +
|
||||||
"$(location " + tfcompile_tool + ")" +
|
"$(location " + tfcompile_tool + ")" +
|
||||||
" --graph=$(location " + tfcompile_graph + ")" +
|
" --graph=$(location " + tfcompile_graph + ")" +
|
||||||
|
debug_info_flag +
|
||||||
" --config=$(location " + config + ")" +
|
" --config=$(location " + config + ")" +
|
||||||
" --entry_point=" + ep +
|
" --entry_point=" + ep +
|
||||||
" --cpp_class=" + cpp_class +
|
" --cpp_class=" + cpp_class +
|
||||||
@ -237,10 +242,7 @@ def tf_library(
|
|||||||
session_module_pb = name + "_session_module.pb"
|
session_module_pb = name + "_session_module.pb"
|
||||||
native.genrule(
|
native.genrule(
|
||||||
name = (name + "_session_module"),
|
name = (name + "_session_module"),
|
||||||
srcs = [
|
srcs = srcs,
|
||||||
tfcompile_graph,
|
|
||||||
config,
|
|
||||||
],
|
|
||||||
outs = [
|
outs = [
|
||||||
session_module_pb,
|
session_module_pb,
|
||||||
],
|
],
|
||||||
@ -248,6 +250,7 @@ def tf_library(
|
|||||||
"CUDA_VISIBLE_DEVICES='' " +
|
"CUDA_VISIBLE_DEVICES='' " +
|
||||||
"$(location " + tfcompile_tool + ")" +
|
"$(location " + tfcompile_tool + ")" +
|
||||||
" --graph=$(location " + tfcompile_graph + ")" +
|
" --graph=$(location " + tfcompile_graph + ")" +
|
||||||
|
debug_info_flag +
|
||||||
" --config=$(location " + config + ")" +
|
" --config=$(location " + config + ")" +
|
||||||
" --entry_point=" + ep +
|
" --entry_point=" + ep +
|
||||||
" --cpp_class=" + cpp_class +
|
" --cpp_class=" + cpp_class +
|
||||||
|
@ -65,6 +65,7 @@ int main(int argc, char** argv) {
|
|||||||
flags.out_metadata_object = "out_helper.o";
|
flags.out_metadata_object = "out_helper.o";
|
||||||
flags.out_header = "out.h";
|
flags.out_header = "out.h";
|
||||||
flags.entry_point = "entry";
|
flags.entry_point = "entry";
|
||||||
|
flags.debug_info_path_begin_marker = "";
|
||||||
|
|
||||||
std::vector<tensorflow::Flag> flag_list;
|
std::vector<tensorflow::Flag> flag_list;
|
||||||
AppendMainFlags(&flag_list, &flags);
|
AppendMainFlags(&flag_list, &flags);
|
||||||
|
@ -177,8 +177,8 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
|
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
|
||||||
"//tensorflow/compiler/xla/client",
|
"//tensorflow/compiler/xla/client",
|
||||||
"//tensorflow/compiler/xla/client:xla_computation",
|
"//tensorflow/compiler/xla/client:xla_computation",
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -86,9 +86,10 @@ Status ConvertOutputInfo(const tf2xla::Config& config,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def,
|
Status ConvertGraphDefToXlaViaMlir(
|
||||||
const tf2xla::Config& config,
|
GraphDef graph_def, const tf2xla::Config& config,
|
||||||
xla::XlaComputation* computation) {
|
xla::XlaComputation* computation, absl::string_view debug_info_filename,
|
||||||
|
absl::string_view debug_info_path_begin_marker) {
|
||||||
// AddPlaceholdersForFeeds prepares for PruneGraphDefInto and serves two
|
// AddPlaceholdersForFeeds prepares for PruneGraphDefInto and serves two
|
||||||
// purposes: (1) It creates a placeholder node for each feed, so that
|
// purposes: (1) It creates a placeholder node for each feed, so that
|
||||||
// PruneGraphDefInfo can prune away the node containing the feed. (2) It
|
// PruneGraphDefInfo can prune away the node containing the feed. (2) It
|
||||||
@ -115,7 +116,24 @@ Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def,
|
|||||||
TF_RETURN_IF_ERROR(ConvertOutputInfo(config, &specs));
|
TF_RETURN_IF_ERROR(ConvertOutputInfo(config, &specs));
|
||||||
|
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
|
if (!debug_info_filename.empty()) {
|
||||||
|
TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_filename, &debug_info));
|
||||||
|
|
||||||
|
if (!debug_info_path_begin_marker.empty()) {
|
||||||
|
for (size_t i = 0, e = debug_info.files_size(); i < e; ++i) {
|
||||||
|
std::string* file_name = debug_info.mutable_files(i);
|
||||||
|
size_t location =
|
||||||
|
file_name->rfind(std::string(debug_info_path_begin_marker));
|
||||||
|
if (location != -1) {
|
||||||
|
*file_name = file_name->substr(location +
|
||||||
|
debug_info_path_begin_marker.length());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
mlir::OwningModuleRef module,
|
mlir::OwningModuleRef module,
|
||||||
ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));
|
ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
|
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
|
||||||
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
|
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||||
#include "tensorflow/compiler/xla/client/client.h"
|
#include "tensorflow/compiler/xla/client/client.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
@ -34,10 +35,16 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
|
|||||||
xla::Client* client,
|
xla::Client* client,
|
||||||
xla::XlaComputation* computation);
|
xla::XlaComputation* computation);
|
||||||
|
|
||||||
// Similar to ConvertGraphDefToXla, but uses MLIR.
|
// Similar to ConvertGraphDefToXla, but uses MLIR and handle debug information.
|
||||||
Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def,
|
//
|
||||||
const tf2xla::Config& config,
|
// debug_info_filename: the file for the debug information proto.
|
||||||
xla::XlaComputation* computation);
|
// debug_info_path_begin_marker: if not empty, file pathes in the debug
|
||||||
|
// information are trimmed from the beginning to the first appearance of the
|
||||||
|
// marker.
|
||||||
|
Status ConvertGraphDefToXlaViaMlir(
|
||||||
|
GraphDef graph_def, const tf2xla::Config& config,
|
||||||
|
xla::XlaComputation* computation, absl::string_view debug_info_filename,
|
||||||
|
absl::string_view debug_info_path_begin_marker);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user