Refactor dump_graph to work on WritableFile internally

The intention is to enable customizing the dump format (demonstrated here with MLIR textual output) so that the format can be switched uniformly to a different one. `WriteTextProto` is expanded into the internal calls it makes. I didn't aggressively factor out the duplication, as the current form reads simpler and leaves hooks for future customization.

This is a roll-forward of the previous attempt to add an MLIR convenience dumper, but it avoids duplicating the file-creation logic and a custom test main. It also makes it easier to enable in existing workflows: link in the library and switch to the different output.
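
A minimal sketch of how a workflow could opt in after this change (illustrative: `graph` is assumed to be an existing tensorflow::Graph, and the dump name is arbitrary):

#include "tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h"
#include "tensorflow/core/util/dump_graph.h"

// Route DumpGraphToFile through the MLIR dumper; subsequent dumps are written
// as textual MLIR with a .mlir suffix under TF_DUMP_GRAPH_PREFIX (or dirname).
tensorflow::UseMlirForGraphDump(
    tensorflow::MlirDumpConfig().emit_location_information());
std::string path = tensorflow::DumpGraphToFile("my_graph", graph);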

PiperOrigin-RevId: 306324082
Change-Id: I20c178990faa9ced0301bcc0ef1b333040c38f72
Jacques Pienaar 2020-04-13 15:45:01 -07:00 committed by TensorFlower Gardener
parent 4d847bd055
commit 3c6327be14
7 changed files with 469 additions and 47 deletions

tensorflow/compiler/mlir/tensorflow/BUILD

@@ -1292,6 +1292,45 @@ tf_cc_test(
     ],
 )
+
+cc_library(
+    name = "dump_graph",
+    srcs = ["utils/dump_graph.cc"],
+    hdrs = ["utils/dump_graph.h"],
+    deps = [
+        ":convert_graphdef",
+        ":error_util",
+        ":tensorflow",
+        ":tensorflow_dialect_registration",
+        ":tensorflow_passes",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/platform:logging",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:Analysis",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+    ],
+)
+
+tf_cc_test(
+    name = "dump_graph_test",
+    size = "small",
+    srcs = ["utils/dump_graph_test.cc"],
+    deps = [
+        ":dump_graph",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/platform:test",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+    ],
+)
 cc_library(
     name = "bridge_logger",
     srcs = ["utils/bridge_logger.cc"],

tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc (new file)

@@ -0,0 +1,105 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h"
#include <cstdint>
#include <cstring>
#include <string>
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/Verifier.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace {
// Simple raw_ostream that prints to a file (doesn't take ownership).
struct WritableFileRawStream : public llvm::raw_ostream {
explicit WritableFileRawStream(WritableFile* file) : file(file) {
SetUnbuffered();
}
~WritableFileRawStream() override = default;
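// The stream is unbuffered and not seekable, so no write position is tracked.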
uint64_t current_pos() const override { return 0; }
void write_impl(const char* ptr, size_t size) override {
// If an error is encountered, null out the file.
if (file) {
Status s = file->Append(StringPiece(ptr, size));
if (!s.ok()) {
LOG(WARNING) << "Write failed: " << s;
file = nullptr;
}
}
}
// The file being written to.
WritableFile* file;
};
} // namespace
Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph,
const FunctionLibraryDefinition* flib_def,
WritableFile* file) {
WritableFileRawStream os(file);
mlir::MLIRContext context;
mlir::OwningModuleRef module;
// Fall back to the graph's function library when no flib_def was provided.
if (!flib_def) {
flib_def = &graph.flib_def();
}
auto convert = [&]() -> Status {
mlir::StatusScopedDiagnosticHandler status_handler(&context);
// TODO(jpienaar): Both the graph debug info and import config should be
// specifiable.
GraphDebugInfo debug_info;
GraphImportConfig import_config;
import_config.graph_as_function = true;
import_config.prune_unused_nodes = false;
TF_ASSIGN_OR_RETURN(
module, ConvertGraphToMlir(graph, debug_info,
flib_def ? *flib_def : graph.flib_def(),
import_config, &context));
if (failed(mlir::verify(*module))) {
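// Verification errors are reported to the diagnostic handler attached to the
// context, so both branches surface the handler's status.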
return status_handler.ConsumeStatus();
}
return status_handler.ConsumeStatus();
};
TF_RETURN_IF_ERROR(convert());
module->print(os, config.op_printing_flags);
return Status::OK();
}
void UseMlirForGraphDump(const MlirDumpConfig& config) {
SetGraphDumper(
[config](const Graph& graph, const FunctionLibraryDefinition* flib_def,
WritableFile* file) -> Status {
return DumpTextualIRToFile(config, graph, flib_def, file);
},
/*suffix=*/".mlir");
}
} // namespace tensorflow

tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h (new file)

@@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_
#include <string>
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
struct MlirDumpConfig;
// Dumps 'graph' to 'file' as textual MLIR. If 'flib_def' is null, the graph's
// own function library is used.
//
// Note: This is for debugging use and is not optimized for performance.
Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph,
const FunctionLibraryDefinition* flib_def,
WritableFile* file);
// Config of the textual dump.
struct MlirDumpConfig {
// Elides elements attributes that have more than the given number of elements.
MlirDumpConfig& elide_large_attributes(int large_element_limit = 16) {
this->op_printing_flags.elideLargeElementsAttrs(large_element_limit);
return *this;
}
// Enable printing of debug information. If 'pretty_form' is set to true,
// debug information is printed in a more readable 'pretty' form but this
// pretty form is not parsable (so only for human readability).
MlirDumpConfig& emit_location_information(bool pretty_form = false) {
this->op_printing_flags.enableDebugInfo(pretty_form);
return *this;
}
// Op printing flags.
mlir::OpPrintingFlags op_printing_flags = llvm::None;
};
// Change DumpGraphToFile to dump MLIR textual IR instead of protobuf.
void UseMlirForGraphDump(const MlirDumpConfig& = {});
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_
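
Usage note on the config above (illustrative values, not part of the diff): both setters return the config by reference, so calls can be chained.

tensorflow::MlirDumpConfig config =
    tensorflow::MlirDumpConfig()
        .elide_large_attributes(/*large_element_limit=*/32)
        .emit_location_information(/*pretty_form=*/true);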

tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc (new file)

@@ -0,0 +1,96 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace {
void ExpectHasSubstr(const string& s, const string& expected) {
EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
// WritableFile that simply concats into string.
class StringWritableFile : public WritableFile {
public:
explicit StringWritableFile(string* str) : str_(*str) {}
Status Append(StringPiece data) override {
absl::StrAppend(&str_, data);
return Status::OK();
}
Status Close() override { return Status::OK(); }
Status Flush() override { return Status::OK(); }
Status Name(StringPiece* result) const override {
*result = "(string)";
return Status::OK();
}
Status Sync() override { return Status::OK(); }
Status Tell(int64* position) override {
return errors::Unimplemented("Stream not seekable");
}
private:
string& str_;
};
TEST(Dump, TextualIrToFileSuccess) {
Graph graph(OpRegistry::Global());
Node* node;
TF_CHECK_OK(NodeBuilder("A", "NoOp").Finalize(&graph, &node));
setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1);
UseMlirForGraphDump(MlirDumpConfig());
string ret = DumpGraphToFile("tir", graph);
ASSERT_EQ(ret, io::JoinPath(testing::TmpDir(), "tir.mlir"));
string actual;
TF_ASSERT_OK(ReadFileToString(Env::Default(), ret, &actual));
string expected_substr = R"(tf_executor.island)";
ExpectHasSubstr(actual, expected_substr);
}
TEST(Dump, TextualIrWithOptions) {
Graph graph(OpRegistry::Global());
Node* node;
TF_ASSERT_OK(NodeBuilder("A", "Placeholder")
.Attr("dtype", DT_FLOAT)
.Finalize(&graph, &node));
string actual;
StringWritableFile file(&actual);
TF_ASSERT_OK(DumpTextualIRToFile(MlirDumpConfig().emit_location_information(),
graph, /*flib_def=*/nullptr, &file));
string expected_substr = R"(loc("A"))";
ExpectHasSubstr(actual, expected_substr);
}
} // namespace
} // namespace tensorflow

tensorflow/core/util/dump_graph.cc

@@ -22,19 +22,22 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/strcat.h"
 namespace tensorflow {
 namespace {
+using strings::StrCat;
 struct NameCounts {
   mutex counts_mutex;
   std::unordered_map<string, int> counts;
 };
-string MakeUniqueFilename(string name) {
+string MakeUniqueFilename(string name, const string& suffix = ".pbtxt") {
   static NameCounts& instance = *new NameCounts;
   // Remove illegal characters from `name`.
@@ -56,30 +59,67 @@ string MakeUniqueFilename(string name) {
   if (count > 0) {
     absl::StrAppend(&filename, "_", count);
   }
-  absl::StrAppend(&filename, ".pbtxt");
+  absl::StrAppend(&filename, suffix);
   return filename;
 }
-#if defined(TENSORFLOW_LITE_PROTOS)
-Status WriteToFile(const string& filepath,
-                   const ::tensorflow::protobuf::MessageLite& proto) {
-  string s;
-  if (!SerializeToStringDeterministic(proto, &s)) {
-    return errors::Internal("Failed to serialize proto to string.");
-  }
-  return WriteStringToFile(Env::Default(), filepath, s);
-}
-#else
-Status WriteToFile(const string& filepath,
-                   const ::tensorflow::protobuf::Message& proto) {
-  return WriteTextProto(Env::Default(), filepath, proto);
-}
-#endif
-template <class T>
-string WriteTextProtoToUniqueFile(Env* env, const string& name,
-                                  const char* proto_type, T& proto,
-                                  const string& dirname) {
+struct GraphDumperConfig {
+  mutex mu;
+  // The dumper and suffix configured.
+  struct Config {
+    bool IsSet() const { return dumper != nullptr; }
+    std::function<Status(const Graph& graph,
+                         const FunctionLibraryDefinition* flib_def,
+                         WritableFile*)>
+        dumper = nullptr;
+    string suffix = ".pbtxt";
+  } config TF_GUARDED_BY(mu);
+  // Returns whether a custom dumper is set.
+  bool IsSet() TF_LOCKS_EXCLUDED(mu) {
+    mutex_lock lock(mu);
+    return config.IsSet();
+  }
+};
+GraphDumperConfig& GetGraphDumperConfig() {
+  static GraphDumperConfig config;
+  return config;
+}
+// WritableFile that simply prints to stderr.
+class StderrWritableFile : public WritableFile {
+ public:
+  StderrWritableFile() {}
+  Status Append(StringPiece data) override {
+    fprintf(stderr, "%.*s", static_cast<int>(data.size()), data.data());
+    return Status::OK();
+  }
+  Status Close() override { return Status::OK(); }
+  Status Flush() override {
+    fflush(stderr);
+    return Status::OK();
+  }
+  Status Name(StringPiece* result) const override {
+    *result = "stderr";
+    return Status::OK();
+  }
+  Status Sync() override { return Status::OK(); }
+  Status Tell(int64* position) override {
+    return errors::Unimplemented("Stream not seekable");
+  }
+};
+Status CreateWritableFile(Env* env, const string& dirname, const string& name,
+                          const string& suffix, string* filepath,
+                          std::unique_ptr<WritableFile>* file) {
   string dir;
   if (!dirname.empty()) {
     dir = dirname;
@@ -92,7 +132,7 @@ string WriteTextProtoToUniqueFile(Env* env, const string& name,
        << "Failed to dump " << name << " because dump location is not "
        << " specified through either TF_DUMP_GRAPH_PREFIX environment "
        << "variable or function argument.";
-    return "(TF_DUMP_GRAPH_PREFIX not specified)";
+    return errors::InvalidArgument("TF_DUMP_GRAPH_PREFIX not specified");
   }
   if (absl::EqualsIgnoreCase(dir, "sponge") ||
@@ -104,40 +144,96 @@ string WriteTextProtoToUniqueFile(Env* env, const string& name,
     }
   }
-  string filepath = "NULL";
+  *filepath = "NULL";
   if (dir == "-") {
-    LOG(INFO) << proto.DebugString();
-    filepath = "LOG(INFO)";
-  } else {
-    Status status = env->RecursivelyCreateDir(dir);
-    if (!status.ok()) {
-      LOG(WARNING) << "Failed to create " << dir << " for dumping "
-                   << proto_type << ": " << status;
-      return "(unavailable)";
-    }
-    filepath = io::JoinPath(dir, MakeUniqueFilename(name));
-    status = WriteToFile(filepath, proto);
-    if (!status.ok()) {
-      LOG(WARNING) << "Failed to dump " << proto_type
-                   << " to file: " << filepath << " : " << status;
-      return "(unavailable)";
-    }
+    *file = std::make_unique<StderrWritableFile>();
+    *filepath = "(stderr)";
+    return Status::OK();
   }
-  LOG(INFO) << "Dumped " << proto_type << " to " << filepath;
-  return filepath;
+  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir));
+  *filepath = io::JoinPath(dir, MakeUniqueFilename(name, suffix));
+  return env->NewWritableFile(*filepath, file);
+}
+template <class T>
+Status WriteTextProtoToUniqueFile(T& proto, WritableFile* file) {
+  string s;
+#if defined(TENSORFLOW_LITE_PROTOS)
+  if (!SerializeToStringDeterministic(proto, &s)) {
+    return errors::Internal("Failed to serialize proto to string.");
+  }
+#else
+  if (!::tensorflow::protobuf::TextFormat::PrintToString(proto, &s)) {
+    return errors::FailedPrecondition("Unable to convert proto to text.");
+  }
+#endif
+  TF_RETURN_IF_ERROR(file->Append(s));
+  return file->Close();
 }
 }  // anonymous namespace
+void SetGraphDumper(
+    std::function<Status(const Graph& graph,
+                         const FunctionLibraryDefinition* flib_def,
+                         WritableFile*)>
+        dumper,
+    string suffix) {
+  GraphDumperConfig& dumper_config = GetGraphDumperConfig();
+  mutex_lock lock(dumper_config.mu);
+  dumper_config.config.dumper = dumper;
+  dumper_config.config.suffix = suffix;
+}
 string DumpGraphDefToFile(const string& name, GraphDef const& graph_def,
                           const string& dirname) {
-  return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", graph_def,
-                                    dirname);
+  string filepath;
+  std::unique_ptr<WritableFile> file;
+  Status status = CreateWritableFile(Env::Default(), dirname, name, ".pbtxt",
+                                     &filepath, &file);
+  if (!status.ok()) {
+    return StrCat("(failed to create writable file: ", status.ToString(), ")");
+  }
+  status = WriteTextProtoToUniqueFile(graph_def, file.get());
+  if (!status.ok()) {
+    return StrCat("(failed to dump Graph to '", filepath,
+                  "': ", status.ToString(), ")");
+  }
+  LOG(INFO) << "Dumped Graph to " << filepath;
+  return filepath;
 }
 string DumpGraphToFile(const string& name, Graph const& graph,
                        const FunctionLibraryDefinition* flib_def,
                        const string& dirname) {
+  auto& dumper_config = GetGraphDumperConfig();
+  if (dumper_config.IsSet()) {
+    GraphDumperConfig::Config config;
+    {
+      mutex_lock lock(dumper_config.mu);
+      config = dumper_config.config;
+    }
+    if (config.IsSet()) {
+      string filepath;
+      std::unique_ptr<WritableFile> file;
+      Status status = CreateWritableFile(Env::Default(), dirname, name,
+                                         config.suffix, &filepath, &file);
+      if (!status.ok()) {
+        return StrCat("(failed to create writable file: ", status.ToString(),
+                      ")");
+      }
+      status = config.dumper(graph, flib_def, file.get());
+      if (!status.ok()) {
+        return StrCat("(failed to dump Graph to '", filepath,
+                      "': ", status.ToString(), ")");
+      }
+      LOG(INFO) << "Dumped Graph to " << filepath;
+      return filepath;
+    }
+  }
   GraphDef graph_def;
   graph.ToGraphDef(&graph_def);
   if (flib_def) {
@@ -148,8 +244,21 @@ string DumpGraphToFile(const string& name, Graph const& graph,
 string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef,
                              const string& dirname) {
-  return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef,
-                                    dirname);
+  string filepath;
+  std::unique_ptr<WritableFile> file;
+  Status status = CreateWritableFile(Env::Default(), dirname, name, ".pbtxt",
+                                     &filepath, &file);
+  if (!status.ok()) {
+    return StrCat("(failed to create writable file: ", status.ToString(), ")");
+  }
+  status = WriteTextProtoToUniqueFile(fdef, file.get());
+  if (!status.ok()) {
+    return StrCat("(failed to dump FunctionDef to '", filepath,
+                  "': ", status.ToString(), ")");
+  }
+  LOG(INFO) << "Dumped FunctionDef to " << filepath;
+  return filepath;
 }
 }  // namespace tensorflow

tensorflow/core/util/dump_graph.h

@@ -50,6 +50,16 @@ string DumpGraphToFile(const string& name, Graph const& graph,
 string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef,
                              const string& dirname = "");
+// Sets a custom Graph dumper. If set, DumpGraphToFile uses this dumper instead
+// of the default text-proto dump. As the custom dumper may not produce a
+// protobuf, the file suffix/extension can be specified too.
+void SetGraphDumper(
+    std::function<Status(const Graph& graph,
+                         const FunctionLibraryDefinition* flib_def,
+                         WritableFile*)>
+        dumper,
+    string suffix = ".pbtxt");
 }  // namespace tensorflow
 #endif  // TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_
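
For illustration only (a sketch, not part of this change): a dumper registered through this hook could emit a plain-text summary rather than a protobuf. The summary format and ".txt" suffix below are hypothetical, and absl/strings/str_cat.h is assumed to be available.

tensorflow::SetGraphDumper(
    [](const tensorflow::Graph& graph,
       const tensorflow::FunctionLibraryDefinition* flib_def,
       tensorflow::WritableFile* file) -> tensorflow::Status {
      // Hypothetical dumper: write a single line with the node count.
      tensorflow::Status s =
          file->Append(absl::StrCat("num_nodes: ", graph.num_nodes(), "\n"));
      if (!s.ok()) return s;
      return file->Close();
    },
    /*suffix=*/".txt");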

tensorflow/core/util/dump_graph_test.cc

@@ -14,8 +14,10 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/util/dump_graph.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
 #include "tensorflow/core/platform/env.h"
@@ -36,7 +38,7 @@ TEST(DumpGraph, DumpGraphToFileSuccess) {
   EXPECT_EQ(ret, io::JoinPath(testing::TmpDir(), "graph_1.pbtxt"));
   GraphDef gdef;
-  TF_CHECK_OK(ReadTextProto(
+  TF_ASSERT_OK(ReadTextProto(
       Env::Default(), io::JoinPath(testing::TmpDir(), "graph.pbtxt"), &gdef));
   string read, written;
   gdef.AppendToString(&read);
@@ -48,7 +50,7 @@ TEST(DumpGraph, DumpGraphToFileNoEnvPrefix) {
   Graph graph(OpRegistry::Global());
   unsetenv("TF_DUMP_GRAPH_PREFIX");
   string ret = DumpGraphToFile("graph", graph);
-  EXPECT_EQ(ret, "(TF_DUMP_GRAPH_PREFIX not specified)");
+  EXPECT_TRUE(str_util::StrContains(ret, "TF_DUMP_GRAPH_PREFIX not specified"));
 }
 TEST(DumpGraph, DumpFunctionDefToFileSuccess) {