Internal change
PiperOrigin-RevId: 299186503 Change-Id: Ic39e151185e5e98636b2b71a9812051c0c357f5c
This commit is contained in:
parent
4cedd7e86d
commit
244a9b0a41
@ -64,6 +64,7 @@ cc_library(
|
||||
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
|
||||
"@llvm-project//llvm:target",
|
||||
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
|
||||
"//tensorflow/core:regexp_internal",
|
||||
] + if_llvm_aarch64_available([
|
||||
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
|
||||
]),
|
||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -105,8 +106,9 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
||||
.ValueOrDie();
|
||||
xla::XlaComputation computation;
|
||||
if (flags.mlir_components == "Bridge") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
|
||||
graph_def, config, &computation, flags.debug_info,
|
||||
flags.debug_info_path_begin_marker));
|
||||
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
||||
client, &computation));
|
||||
@ -166,6 +168,23 @@ static void InitializeTargets() {
|
||||
LLVMInitializeX86AsmPrinter();
|
||||
}
|
||||
|
||||
// Replaces {{tag.type tag.name}} in the error message with tag_name.
|
||||
// TODO(bixia): We currently only handlge tag.type == "node".
|
||||
//
|
||||
// In the error message, a graph node is represented as {{tag.type, tag.name}},
|
||||
// to allow a Python debugger to insert source information about the graph node.
|
||||
// For example, a Python add expression may be represented as
|
||||
// {{node, x_y_sum}} = Add(x, y) in the error message. See routine interpolate
|
||||
// in tensorflow/python/framework/error_interpolation.py for more detail.
|
||||
static std::string InterpolateErrorMessage(std::string message) {
|
||||
// See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
|
||||
// Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
|
||||
static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
|
||||
RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
Status Main(const MainFlags& flags) {
|
||||
absl::call_once(targets_init, &InitializeTargets);
|
||||
|
||||
@ -192,8 +211,13 @@ Status Main(const MainFlags& flags) {
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
||||
CompileResult compile_result;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileGraph(std::move(graph_def), config, flags, &compile_result));
|
||||
|
||||
Status status =
|
||||
CompileGraph(std::move(graph_def), config, flags, &compile_result);
|
||||
if (!status.ok()) {
|
||||
return Status(status.code(),
|
||||
InterpolateErrorMessage(status.error_message()));
|
||||
}
|
||||
|
||||
// Write output files.
|
||||
Env* env = Env::Default();
|
||||
|
@ -24,6 +24,13 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
|
||||
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
|
||||
"be in the human-readable proto text format, otherwise it is expected "
|
||||
"to be in the proto binary format."},
|
||||
{"debug_info", &flags->debug_info,
|
||||
"Graph debug info file. If the file ends in '.pbtxt' it is expected to "
|
||||
"be in the human-readable proto text format, otherwise it is expected "
|
||||
"to be in the proto binary format."},
|
||||
{"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker,
|
||||
"If not none, only keep the file path in the debug information after the"
|
||||
" marker. The default value is empty"},
|
||||
{"config", &flags->config,
|
||||
"Input file containing Config proto. If the file ends in '.pbtxt' it "
|
||||
"is expected to be in the human-readable proto text format, otherwise "
|
||||
|
@ -28,6 +28,8 @@ namespace tfcompile {
|
||||
|
||||
struct MainFlags {
|
||||
string graph;
|
||||
string debug_info;
|
||||
string debug_info_path_begin_marker;
|
||||
string config;
|
||||
bool dump_fetch_nodes = false;
|
||||
string target_triple;
|
||||
|
@ -1,11 +1,34 @@
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":filecheck_test_utilities"],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = ["lit.pbtxt"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "filecheck_test_utilities",
|
||||
testonly = True,
|
||||
srcs = [
|
||||
"test_error_message.lit.pbtxt.config.pbtxt",
|
||||
"test_error_message.lit.pbtxt.debug.pbtxt",
|
||||
"test_error_message.lit.pbtxt.fake_py.debug",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/compiler/aot:tfcompile",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
# We disable some tfcompile tests in the open source build with the
|
||||
# "manual" tag to avoid making our OSS users build LLVM twice
|
||||
# (once for host and once for target).
|
||||
@ -60,6 +83,7 @@ genrule(
|
||||
testonly = 1,
|
||||
outs = [
|
||||
"test_graph_tfadd.pb",
|
||||
"test_debuginfo_tfadd.pb",
|
||||
"test_graph_tfadd_with_ckpt.ckpt",
|
||||
"test_graph_tfadd_with_ckpt.pb",
|
||||
"test_graph_tfadd_with_ckpt_saver.ckpt",
|
||||
@ -317,6 +341,7 @@ tf_library(
|
||||
testonly = 1,
|
||||
config = "test_graph_tfadd.config.pbtxt",
|
||||
cpp_class = "AddComp",
|
||||
debug_info = "test_debuginfo_tfadd.pb",
|
||||
graph = "test_graph_tfadd.pb",
|
||||
include_standard_runtime_deps = False,
|
||||
mlir_components = "Bridge",
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import error_interpolation
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -184,7 +185,22 @@ def tfvariable_sequential_updates(_):
|
||||
array_ops.identity(updates, name='result')
|
||||
|
||||
|
||||
def write_graph(build_graph, out_dir):
|
||||
def export_debug_info(exported_graph):
|
||||
"""Exports debug information from a graph.
|
||||
|
||||
Args:
|
||||
exported_graph: A Graph that has been created by tracing a saveable view.
|
||||
|
||||
Returns:
|
||||
Corresponding GraphDebugInfo with traces for all ops in exported_graph.
|
||||
"""
|
||||
exported_operations = []
|
||||
for op in exported_graph.get_operations():
|
||||
exported_operations.append(('', op))
|
||||
return error_interpolation.create_graph_debug_info_def(exported_operations)
|
||||
|
||||
|
||||
def write_graph(build_graph, out_dir, debug_info=False):
|
||||
"""Build a graph using build_graph and write it out."""
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
@ -193,10 +209,19 @@ def write_graph(build_graph, out_dir):
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
|
||||
|
||||
if debug_info:
|
||||
filename_debuginfo = os.path.join(
|
||||
out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__)
|
||||
test_debuginfo = export_debug_info(g)
|
||||
with open(filename_debuginfo, 'wb') as f:
|
||||
f.write(
|
||||
six.ensure_binary(
|
||||
test_debuginfo.SerializeToString(deterministic=True)))
|
||||
|
||||
|
||||
def main(_):
|
||||
control_flow_util.enable_control_flow_v2()
|
||||
write_graph(tfadd, FLAGS.out_dir)
|
||||
write_graph(tfadd, FLAGS.out_dir, debug_info=True)
|
||||
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
|
||||
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
|
||||
write_graph(tfassert_eq, FLAGS.out_dir)
|
||||
|
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
69
tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt
Normal file
@ -0,0 +1,69 @@
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
|
||||
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
|
||||
|
||||
# Checks the error message produced by tfcompile with mlir_component
|
||||
# Checks that source debug information is used in the output error message and
|
||||
# the node x_y_sum = Add
|
||||
# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
|
||||
# CHECK: math_ops.add(x, y, name='x_y_sum')
|
||||
# CHECK: build_graph(out_dir)
|
||||
|
||||
# Checks the error message produced by tfcompile without mlir_component
|
||||
# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3]
|
||||
# OLD: x_y_sum
|
||||
|
||||
node: {
|
||||
name: "x"
|
||||
op: "Placeholder"
|
||||
attr: {
|
||||
key: "shape"
|
||||
value: {
|
||||
shape: {
|
||||
dim: {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "y"
|
||||
op: "Placeholder"
|
||||
attr: {
|
||||
key: "shape"
|
||||
value: {
|
||||
shape: {
|
||||
dim: {
|
||||
size: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "x_y_sum"
|
||||
op: "Add"
|
||||
input: "x"
|
||||
input: "y"
|
||||
attr: {
|
||||
key: "T"
|
||||
value: {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions: {
|
||||
producer: 321
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
# Text form of tensorflow.tf2xla.Config proto.
|
||||
feed {
|
||||
id { node_name: "x" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y" }
|
||||
shape {
|
||||
dim { size: 3 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_sum" }
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug"
|
||||
traces: {
|
||||
key: "x@"
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
traces: {
|
||||
key: "x_y_sum@"
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 3
|
||||
}
|
||||
file_line_cols: {
|
||||
line: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
traces: {
|
||||
key: "y@"
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 2
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
x = value
|
||||
y = value
|
||||
math_ops.add(x, y, name='x_y_sum')
|
||||
build_graph(out_dir)
|
@ -26,6 +26,7 @@ def tf_library(
|
||||
name,
|
||||
graph,
|
||||
config,
|
||||
debug_info = None,
|
||||
freeze_checkpoint = None,
|
||||
freeze_saver = None,
|
||||
cpp_class = None,
|
||||
@ -191,12 +192,15 @@ def tf_library(
|
||||
|
||||
mlir_flag = "--mlir_components=" + mlir_components
|
||||
|
||||
srcs = [tfcompile_graph, config]
|
||||
debug_info_flag = ""
|
||||
if debug_info:
|
||||
srcs.append(debug_info)
|
||||
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||
|
||||
native.genrule(
|
||||
name = ("gen_" + name),
|
||||
srcs = [
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
srcs = srcs,
|
||||
outs = [
|
||||
header_file,
|
||||
metadata_object_file,
|
||||
@ -206,6 +210,7 @@ def tf_library(
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
debug_info_flag +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
@ -237,10 +242,7 @@ def tf_library(
|
||||
session_module_pb = name + "_session_module.pb"
|
||||
native.genrule(
|
||||
name = (name + "_session_module"),
|
||||
srcs = [
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
srcs = srcs,
|
||||
outs = [
|
||||
session_module_pb,
|
||||
],
|
||||
@ -248,6 +250,7 @@ def tf_library(
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
debug_info_flag +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
|
@ -65,6 +65,7 @@ int main(int argc, char** argv) {
|
||||
flags.out_metadata_object = "out_helper.o";
|
||||
flags.out_header = "out.h";
|
||||
flags.entry_point = "entry";
|
||||
flags.debug_info_path_begin_marker = "";
|
||||
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
|
@ -177,8 +177,8 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
|
||||
"//tensorflow/compiler/xla/client",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -86,9 +86,10 @@ Status ConvertOutputInfo(const tf2xla::Config& config,
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def,
|
||||
const tf2xla::Config& config,
|
||||
xla::XlaComputation* computation) {
|
||||
Status ConvertGraphDefToXlaViaMlir(
|
||||
GraphDef graph_def, const tf2xla::Config& config,
|
||||
xla::XlaComputation* computation, absl::string_view debug_info_filename,
|
||||
absl::string_view debug_info_path_begin_marker) {
|
||||
// AddPlaceholdersForFeeds prepares for PruneGraphDefInto and serves two
|
||||
// purposes: (1) It creates a placeholder node for each feed, so that
|
||||
// PruneGraphDefInfo can prune away the node containing the feed. (2) It
|
||||
@ -115,7 +116,24 @@ Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def,
|
||||
TF_RETURN_IF_ERROR(ConvertOutputInfo(config, &specs));
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
if (!debug_info_filename.empty()) {
|
||||
TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_filename, &debug_info));
|
||||
|
||||
if (!debug_info_path_begin_marker.empty()) {
|
||||
for (size_t i = 0, e = debug_info.files_size(); i < e; ++i) {
|
||||
std::string* file_name = debug_info.mutable_files(i);
|
||||
size_t location =
|
||||
file_name->rfind(std::string(debug_info_path_begin_marker));
|
||||
if (location != -1) {
|
||||
*file_name = file_name->substr(location +
|
||||
debug_info_path_begin_marker.length());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mlir::MLIRContext context;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
mlir::OwningModuleRef module,
|
||||
ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/client/client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
@ -34,10 +35,16 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
|
||||
xla::Client* client,
|
||||
xla::XlaComputation* computation);
|
||||
|
||||
// Similar to ConvertGraphDefToXla, but uses MLIR.
|
||||
Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def,
|
||||
const tf2xla::Config& config,
|
||||
xla::XlaComputation* computation);
|
||||
// Similar to ConvertGraphDefToXla, but uses MLIR and handle debug information.
|
||||
//
|
||||
// debug_info_filename: the file for the debug information proto.
|
||||
// debug_info_path_begin_marker: if not empty, file pathes in the debug
|
||||
// information are trimmed from the beginning to the first appearance of the
|
||||
// marker.
|
||||
Status ConvertGraphDefToXlaViaMlir(
|
||||
GraphDef graph_def, const tf2xla::Config& config,
|
||||
xla::XlaComputation* computation, absl::string_view debug_info_filename,
|
||||
absl::string_view debug_info_path_begin_marker);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user