Internal change

PiperOrigin-RevId: 299186503
Change-Id: Ic39e151185e5e98636b2b71a9812051c0c357f5c
This commit is contained in:
Bixia Zheng 2020-03-05 13:57:31 -08:00 committed by TensorFlower Gardener
parent 4cedd7e86d
commit 244a9b0a41
15 changed files with 252 additions and 22 deletions

View File

@ -64,6 +64,7 @@ cc_library(
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"//tensorflow/core:regexp_internal",
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),

View File

@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@ -105,8 +106,9 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
.ValueOrDie();
xla::XlaComputation computation;
if (flags.mlir_components == "Bridge") {
TF_RETURN_IF_ERROR(
ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
graph_def, config, &computation, flags.debug_info,
flags.debug_info_path_begin_marker));
} else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
client, &computation));
@ -166,6 +168,23 @@ static void InitializeTargets() {
LLVMInitializeX86AsmPrinter();
}
// Replaces {{tag.type tag.name}} in the error message with tag.name.
// TODO(bixia): We currently only handle tag.type == "node".
//
// In the error message, a graph node is represented as {{tag.type tag.name}},
// to allow a Python debugger to insert source information about the graph node.
// For example, a Python add expression may be represented as
// {{node x_y_sum}} = Add(x, y) in the error message. See routine interpolate
// in tensorflow/python/framework/error_interpolation.py for more detail.
static std::string InterpolateErrorMessage(std::string message) {
// See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
// Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
static LazyRE2 pattern{"(.*){{node (.*)}}(.*)"};
RE2::GlobalReplace(&message, *pattern, "\\1\\2\\3");
return message;
}
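For illustration, a minimal standalone sketch of the same rewrite, assuming RE2 is used directly outside of TensorFlow; the sample message below is made up to resemble the one exercised by the lit test added in this change:

#include <iostream>
#include <string>

#include "re2/re2.h"

int main() {
  std::string message =
      "Dimensions must be equal, but are 2 and 3 for "
      "'{{node x_y_sum}} = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'";
  // Same pattern as InterpolateErrorMessage: strip the "{{node " and "}}"
  // wrappers and keep only tag.name.
  RE2::GlobalReplace(&message, "(.*){{node (.*)}}(.*)", "\\1\\2\\3");
  // Prints: ... for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
  std::cout << message << std::endl;
  return 0;
}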
Status Main(const MainFlags& flags) {
absl::call_once(targets_init, &InitializeTargets);
@ -192,8 +211,13 @@ Status Main(const MainFlags& flags) {
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph_def), config, flags, &compile_result));
Status status =
CompileGraph(std::move(graph_def), config, flags, &compile_result);
if (!status.ok()) {
return Status(status.code(),
InterpolateErrorMessage(status.error_message()));
}
// Write output files.
Env* env = Env::Default();

View File

@ -24,6 +24,13 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
"be in the human-readable proto text format, otherwise it is expected "
"to be in the proto binary format."},
{"debug_info", &flags->debug_info,
"Graph debug info file. If the file ends in '.pbtxt' it is expected to "
"be in the human-readable proto text format, otherwise it is expected "
"to be in the proto binary format."},
{"debug_info_path_begin_marker", &flags->debug_info_path_begin_marker,
"If not empty, only keep the part of each file path in the debug "
"information after the marker. The default value is empty."},
{"config", &flags->config,
"Input file containing Config proto. If the file ends in '.pbtxt' it "
"is expected to be in the human-readable proto text format, otherwise "

View File

@ -28,6 +28,8 @@ namespace tfcompile {
struct MainFlags {
string graph;
string debug_info;
string debug_info_path_begin_marker;
string config;
bool dump_fetch_nodes = false;
string target_triple;

View File

@ -1,11 +1,34 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
glob_lit_tests(
data = [":filecheck_test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["lit.pbtxt"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "filecheck_test_utilities",
testonly = True,
srcs = [
"test_error_message.lit.pbtxt.config.pbtxt",
"test_error_message.lit.pbtxt.debug.pbtxt",
"test_error_message.lit.pbtxt.fake_py.debug",
],
data = [
"//tensorflow/compiler/aot:tfcompile",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
# We disable some tfcompile tests in the open source build with the
# "manual" tag to avoid making our OSS users build LLVM twice
# (once for host and once for target).
@ -60,6 +83,7 @@ genrule(
testonly = 1,
outs = [
"test_graph_tfadd.pb",
"test_debuginfo_tfadd.pb",
"test_graph_tfadd_with_ckpt.ckpt",
"test_graph_tfadd_with_ckpt.pb",
"test_graph_tfadd_with_ckpt_saver.ckpt",
@ -317,6 +341,7 @@ tf_library(
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
debug_info = "test_debuginfo_tfadd.pb",
graph = "test_graph_tfadd.pb",
include_standard_runtime_deps = False,
mlir_components = "Bridge",

View File

@ -30,6 +30,7 @@ from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -184,7 +185,22 @@ def tfvariable_sequential_updates(_):
array_ops.identity(updates, name='result')
def write_graph(build_graph, out_dir):
def export_debug_info(exported_graph):
"""Exports debug information from a graph.
Args:
exported_graph: A Graph that has been created by tracing a saveable view.
Returns:
Corresponding GraphDebugInfo with traces for all ops in exported_graph.
"""
exported_operations = []
for op in exported_graph.get_operations():
exported_operations.append(('', op))
return error_interpolation.create_graph_debug_info_def(exported_operations)
def write_graph(build_graph, out_dir, debug_info=False):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
with g.as_default():
@ -193,10 +209,19 @@ def write_graph(build_graph, out_dir):
with open(filename, 'wb') as f:
f.write(six.ensure_binary(g.as_graph_def().SerializeToString()))
if debug_info:
filename_debuginfo = os.path.join(
out_dir, 'test_debuginfo_%s.pb' % build_graph.__name__)
test_debuginfo = export_debug_info(g)
with open(filename_debuginfo, 'wb') as f:
f.write(
six.ensure_binary(
test_debuginfo.SerializeToString(deterministic=True)))
def main(_):
control_flow_util.enable_control_flow_v2()
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd, FLAGS.out_dir, debug_info=True)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
write_graph(tfassert_eq, FLAGS.out_dir)

View File

@ -0,0 +1,69 @@
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
# Checks the error message produced by tfcompile with mlir_components=Bridge.
# Checks that source debug information is used in the output error message for
# the node x_y_sum = Add.
# CHECK: INVALID ARGUMENTS: Dimensions must be equal, but are 2 and 3 for 'x_y_sum = Add[T=DT_INT32](aot_feed_0/x, aot_feed_0/y)'
# CHECK: math_ops.add(x, y, name='x_y_sum')
# CHECK: build_graph(out_dir)
# Checks the error message produced by tfcompile with mlir_components=None.
# OLD: INVALID ARGUMENTS: Incompatible shapes: [2] vs. [3]
# OLD: x_y_sum
node: {
name: "x"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "y"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "x_y_sum"
op: "Add"
input: "x"
input: "y"
attr: {
key: "T"
value: {
type: DT_INT32
}
}
}
versions: {
producer: 321
}

View File

@ -0,0 +1,16 @@
# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x" }
shape {
dim { size: 2 }
}
}
feed {
id { node_name: "y" }
shape {
dim { size: 3 }
}
}
fetch {
id { node_name: "x_y_sum" }
}

View File

@ -0,0 +1,28 @@
files: "org_tensorflow/tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug"
traces: {
key: "x@"
value: {
file_line_cols: {
line: 1
}
}
}
traces: {
key: "x_y_sum@"
value: {
file_line_cols: {
line: 3
}
file_line_cols: {
line: 4
}
}
}
traces: {
key: "y@"
value: {
file_line_cols: {
line: 2
}
}
}
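The trace keys above use the "<node_name>@<function_name>" form (empty function name for top-level ops), and each file_line_cols entry indexes into the files list (file_index defaults to 0, i.e. the fake_py.debug file). A rough sketch of building the x_y_sum entry programmatically, assuming the GraphDebugInfo layout in tensorflow/core/protobuf/graph_debug_info.proto:

#include "tensorflow/core/protobuf/graph_debug_info.pb.h"

// Sketch only: reproduces the "x_y_sum@" trace from the pbtxt above.
tensorflow::GraphDebugInfo MakeDebugInfoForXYSum() {
  tensorflow::GraphDebugInfo debug_info;
  debug_info.add_files(
      "org_tensorflow/tensorflow/compiler/aot/tests/"
      "test_error_message.lit.pbtxt.fake_py.debug");
  auto& trace = (*debug_info.mutable_traces())["x_y_sum@"];
  trace.add_file_line_cols()->set_line(3);  // math_ops.add(x, y, name='x_y_sum')
  trace.add_file_line_cols()->set_line(4);  // build_graph(out_dir)
  return debug_info;
}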

View File

@ -0,0 +1,4 @@
x = value
y = value
math_ops.add(x, y, name='x_y_sum')
build_graph(out_dir)

View File

@ -26,6 +26,7 @@ def tf_library(
name,
graph,
config,
debug_info = None,
freeze_checkpoint = None,
freeze_saver = None,
cpp_class = None,
@ -191,12 +192,15 @@ def tf_library(
mlir_flag = "--mlir_components=" + mlir_components
srcs = [tfcompile_graph, config]
debug_info_flag = ""
if debug_info:
srcs.append(debug_info)
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
native.genrule(
name = ("gen_" + name),
srcs = [
tfcompile_graph,
config,
],
srcs = srcs,
outs = [
header_file,
metadata_object_file,
@ -206,6 +210,7 @@ def tf_library(
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
debug_info_flag +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +
@ -237,10 +242,7 @@ def tf_library(
session_module_pb = name + "_session_module.pb"
native.genrule(
name = (name + "_session_module"),
srcs = [
tfcompile_graph,
config,
],
srcs = srcs,
outs = [
session_module_pb,
],
@ -248,6 +250,7 @@ def tf_library(
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
debug_info_flag +
" --config=$(location " + config + ")" +
" --entry_point=" + ep +
" --cpp_class=" + cpp_class +

View File

@ -65,6 +65,7 @@ int main(int argc, char** argv) {
flags.out_metadata_object = "out_helper.o";
flags.out_header = "out.h";
flags.entry_point = "entry";
flags.debug_info_path_begin_marker = "";
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);

View File

@ -177,8 +177,8 @@ cc_library(
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
)

View File

@ -86,9 +86,10 @@ Status ConvertOutputInfo(const tf2xla::Config& config,
} // namespace
Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def,
const tf2xla::Config& config,
xla::XlaComputation* computation) {
Status ConvertGraphDefToXlaViaMlir(
GraphDef graph_def, const tf2xla::Config& config,
xla::XlaComputation* computation, absl::string_view debug_info_filename,
absl::string_view debug_info_path_begin_marker) {
// AddPlaceholdersForFeeds prepares for PruneGraphDefInto and serves two
// purposes: (1) It creates a placeholder node for each feed, so that
PruneGraphDefInto can prune away the node containing the feed. (2) It
@ -115,7 +116,24 @@ Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def,
TF_RETURN_IF_ERROR(ConvertOutputInfo(config, &specs));
GraphDebugInfo debug_info;
if (!debug_info_filename.empty()) {
TF_RETURN_IF_ERROR(LoadProtoFromFile(debug_info_filename, &debug_info));
if (!debug_info_path_begin_marker.empty()) {
for (size_t i = 0, e = debug_info.files_size(); i < e; ++i) {
std::string* file_name = debug_info.mutable_files(i);
size_t location =
file_name->rfind(std::string(debug_info_path_begin_marker));
if (location != std::string::npos) {
*file_name = file_name->substr(location +
debug_info_path_begin_marker.length());
}
}
}
}
mlir::MLIRContext context;
TF_ASSIGN_OR_RETURN(
mlir::OwningModuleRef module,
ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));
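A minimal standalone sketch of the path trimming applied above, using only the C++ standard library; the marker value "org_tensorflow/" is a hypothetical example:

#include <iostream>
#include <string>

// Keep only the part of the path after the last occurrence of the marker,
// mirroring the debug_info_path_begin_marker handling above.
std::string TrimDebugInfoPath(std::string file_name, const std::string& marker) {
  if (marker.empty()) return file_name;
  size_t location = file_name.rfind(marker);
  if (location != std::string::npos) {
    file_name = file_name.substr(location + marker.length());
  }
  return file_name;
}

int main() {
  std::cout << TrimDebugInfoPath(
                   "org_tensorflow/tensorflow/compiler/aot/tests/"
                   "test_error_message.lit.pbtxt.fake_py.debug",
                   "org_tensorflow/")
            << std::endl;
  // Prints: tensorflow/compiler/aot/tests/test_error_message.lit.pbtxt.fake_py.debug
  return 0;
}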

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@ -34,10 +35,16 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
xla::Client* client,
xla::XlaComputation* computation);
// Similar to ConvertGraphDefToXla, but uses MLIR.
Status ConvertGraphDefToXlaViaMlir(GraphDef graph_def,
const tf2xla::Config& config,
xla::XlaComputation* computation);
// Similar to ConvertGraphDefToXla, but uses MLIR and handles debug information.
//
// debug_info_filename: the file for the debug information proto.
// debug_info_path_begin_marker: if not empty, file paths in the debug
// information are trimmed from the beginning up to and including the last
// occurrence of the marker.
Status ConvertGraphDefToXlaViaMlir(
GraphDef graph_def, const tf2xla::Config& config,
xla::XlaComputation* computation, absl::string_view debug_info_filename,
absl::string_view debug_info_path_begin_marker);
} // namespace tensorflow