From 8ff33271ea4de89e6ff662fe8e479c1fcf56fe77 Mon Sep 17 00:00:00 2001 From: Justin Lebar <jlebar@google.com> Date: Fri, 20 Oct 2017 15:55:57 -0700 Subject: [PATCH] Dump the computation's SessionModule as part of the tf_compile rule. PiperOrigin-RevId: 172946149 --- tensorflow/compiler/aot/compile.cc | 6 +++--- tensorflow/compiler/aot/flags.cc | 5 ++--- tensorflow/compiler/aot/flags.h | 2 +- tensorflow/compiler/aot/tfcompile.bzl | 3 +++ 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index eac8da0ab1b..77c4ec88cbe 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -97,11 +97,11 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client, &computation, &compile_result->has_context_arg)); - if (!flags.debug_dir.empty()) { + if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module, computation.Snapshot()); - string file = io::JoinPath(flags.debug_dir, "tfcompile_xla_module.pb"); - TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), file, *module)); + TF_RETURN_IF_ERROR( + WriteBinaryProto(Env::Default(), flags.out_session_module, *module)); } xla::cpu::CpuAotCompilationOptions aot_opts( flags.target_triple, flags.target_cpu, flags.target_features, diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc index 5aff10346fa..7c2f27e550d 100644 --- a/tensorflow/compiler/aot/flags.cc +++ b/tensorflow/compiler/aot/flags.cc @@ -33,9 +33,6 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) { "fetch nodes will be dumped to stdout in a comma-separated list. " "Typically used to format arguments for other tools, e.g. " "freeze_graph."}, - {"debug_dir", &flags->debug_dir, - "Specifies a directory to dump debugging information, including " - "rewritten graphs and the XLA HLO module."}, // Flags controlling the XLA ahead-of-time compilation, that correspond to // the fields of xla::cpu::CpuAotCompilationOptions. // @@ -64,6 +61,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) { "namespaces are given, within the global namespace."}, {"out_object", &flags->out_object, "Output object file name."}, {"out_header", &flags->out_header, "Output header file name."}, + {"out_session_module", &flags->out_session_module, + "Output session module proto."}, {"gen_name_to_index", &flags->gen_name_to_index, "Generate name-to-index data for Lookup{Arg,Result}Index methods."}, {"gen_program_shape", &flags->gen_program_shape, diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index 3246dbf95c8..3519659e3af 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -29,7 +29,6 @@ struct MainFlags { string graph; string config; bool dump_fetch_nodes = false; - string debug_dir; string target_triple; string target_cpu; string target_features; @@ -37,6 +36,7 @@ struct MainFlags { string cpp_class; string out_object; string out_header; + string out_session_module; // C++ codegen options bool gen_name_to_index = false; diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 4888760acd4..0ecfbedcb42 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -129,6 +129,7 @@ def tf_library(name, graph, config, # Rule that runs tfcompile to produce the header and object file. header_file = name + ".h" object_file = name + ".o" + session_module_pb = name + "_session_module.pb" ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") native.genrule( name=("gen_" + name), @@ -139,6 +140,7 @@ def tf_library(name, graph, config, outs=[ header_file, object_file, + session_module_pb, ], cmd=("$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + @@ -148,6 +150,7 @@ def tf_library(name, graph, config, " --target_triple=" + target_llvm_triple() + " --out_header=$(@D)/" + header_file + " --out_object=$(@D)/" + object_file + + " --out_session_module=$(@D)/" + session_module_pb + " " + (tfcompile_flags or "")), tools=[tfcompile_tool], visibility=visibility,