diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index af4235668d8..4305d64c864 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1292,6 +1292,45 @@ tf_cc_test( ], ) +cc_library( + name = "dump_graph", + srcs = ["utils/dump_graph.cc"], + hdrs = ["utils/dump_graph.h"], + deps = [ + ":convert_graphdef", + ":error_util", + ":tensorflow", + ":tensorflow_dialect_registration", + ":tensorflow_passes", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core/platform:logging", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +tf_cc_test( + name = "dump_graph_test", + size = "small", + srcs = ["utils/dump_graph_test.cc"], + deps = [ + ":dump_graph", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:test", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "bridge_logger", srcs = ["utils/bridge_logger.cc"], diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc new file mode 100644 index 00000000000..ffcd1f71a50 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h" + +#include <cstdint> +#include <cstring> +#include <string> + +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Analysis/Verifier.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { + +namespace { + +// Simple raw_ostream that prints to a file (doesn't take ownership). +struct WritableFileRawStream : public llvm::raw_ostream { + explicit WritableFileRawStream(WritableFile* file) : file(file) { + SetUnbuffered(); + } + ~WritableFileRawStream() override = default; + uint64_t current_pos() const override { return 0; } + + void write_impl(const char* ptr, size_t size) override { + // If an error is encountered, null out the file. + if (file) { + Status s = file->Append(StringPiece(ptr, size)); + if (!s.ok()) { + LOG(WARNING) << "Write failed: " << s; + file = nullptr; + } + } + } + + // The file being written to. + WritableFile* file; +}; +} // namespace + +Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph, + const FunctionLibraryDefinition* flib_def, + WritableFile* file) { + WritableFileRawStream os(std::move(file)); + mlir::MLIRContext context; + mlir::OwningModuleRef module; + if (flib_def) { + flib_def = &graph.flib_def(); + } + auto convert = [&]() -> Status { + mlir::StatusScopedDiagnosticHandler status_handler(&context); + // TODO(jpienaar): Both the graph debug info and import config should be + // specifiable. + GraphDebugInfo debug_info; + GraphImportConfig import_config; + import_config.graph_as_function = true; + import_config.prune_unused_nodes = false; + TF_ASSIGN_OR_RETURN( + module, ConvertGraphToMlir(graph, debug_info, + flib_def ? *flib_def : graph.flib_def(), + import_config, &context)); + if (failed(mlir::verify(*module))) { + return status_handler.ConsumeStatus(); + } + return status_handler.ConsumeStatus(); + }; + + TF_RETURN_IF_ERROR(convert()); + module->print(os, config.op_printing_flags); + return Status::OK(); +} + +void UseMlirForGraphDump(const MlirDumpConfig& config) { + SetGraphDumper( + [config](const Graph& graph, const FunctionLibraryDefinition* flib_def, + WritableFile* file) -> Status { + return DumpTextualIRToFile(config, graph, flib_def, file); + }, + /*suffix=*/".mlir"); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h new file mode 100644 index 00000000000..b5976420231 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_ + +#include <string> + +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +struct MlirDumpConfig; + +// Dumps 'graph_def' to a file, as textual IR. Returns the file name chosen. +// +// Note: This is for debugging use and is not optimized for performance. +Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph, + const FunctionLibraryDefinition* flib_def, + WritableFile* file); + +// Config of the textual dump. +struct MlirDumpConfig { + // The limit of element size that gets printed. + MlirDumpConfig& elide_large_attributes(int large_element_limit = 16) { + this->op_printing_flags.elideLargeElementsAttrs(large_element_limit); + return *this; + } + + // Enable printing of debug information. If 'pretty_form' is set to true, + // debug information is printed in a more readable 'pretty' form but this + // pretty form is not parsable (so only for human readability). + MlirDumpConfig& emit_location_information(bool pretty_form = false) { + this->op_printing_flags.enableDebugInfo(pretty_form); + return *this; + } + + // Op printing flags. + mlir::OpPrintingFlags op_printing_flags = llvm::None; +}; + +// Change DumpGraphToFile to dump MLIR textual IR instead of protobuf. +void UseMlirForGraphDump(const MlirDumpConfig& = {}); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc new file mode 100644 index 00000000000..5e700fe02a5 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h" + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { +namespace { + +void ExpectHasSubstr(const string& s, const string& expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + +// WritableFile that simply concats into string. +class StringWritableFile : public WritableFile { + public: + explicit StringWritableFile(string* str) : str_(*str) {} + + Status Append(StringPiece data) override { + absl::StrAppend(&str_, data); + return Status::OK(); + } + + Status Close() override { return Status::OK(); } + + Status Flush() override { return Status::OK(); } + + Status Name(StringPiece* result) const override { + *result = "(string)"; + return Status::OK(); + } + + Status Sync() override { return Status::OK(); } + + Status Tell(int64* position) override { + return errors::Unimplemented("Stream not seekable"); + } + + private: + string& str_; +}; + +TEST(Dump, TexualIrToFileSuccess) { + Graph graph(OpRegistry::Global()); + Node* node; + TF_CHECK_OK(NodeBuilder("A", "NoOp").Finalize(&graph, &node)); + + setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1); + UseMlirForGraphDump(MlirDumpConfig()); + string ret = DumpGraphToFile("tir", graph); + ASSERT_EQ(ret, io::JoinPath(testing::TmpDir(), "tir.mlir")); + + string actual; + TF_ASSERT_OK(ReadFileToString(Env::Default(), ret, &actual)); + string expected_substr = R"(tf_executor.island)"; + ExpectHasSubstr(actual, expected_substr); +} + +TEST(Dump, TexualIrWithOptions) { + Graph graph(OpRegistry::Global()); + Node* node; + TF_ASSERT_OK(NodeBuilder("A", "Placeholder") + .Attr("dtype", DT_FLOAT) + .Finalize(&graph, &node)); + + string actual; + StringWritableFile file(&actual); + TF_ASSERT_OK(DumpTextualIRToFile(MlirDumpConfig().emit_location_information(), + graph, /*flib_def=*/nullptr, &file)); + + string expected_substr = R"(loc("A"))"; + ExpectHasSubstr(actual, expected_substr); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/dump_graph.cc b/tensorflow/core/util/dump_graph.cc index 50c149d48a6..dd273af2a00 100644 --- a/tensorflow/core/util/dump_graph.cc +++ b/tensorflow/core/util/dump_graph.cc @@ -22,19 +22,22 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/strcat.h" namespace tensorflow { namespace { +using strings::StrCat; struct NameCounts { mutex counts_mutex; std::unordered_map<string, int> counts; }; -string MakeUniqueFilename(string name) { +string MakeUniqueFilename(string name, const string& suffix = ".pbtxt") { static NameCounts& instance = *new NameCounts; // Remove illegal characters from `name`. @@ -56,30 +59,67 @@ string MakeUniqueFilename(string name) { if (count > 0) { absl::StrAppend(&filename, "_", count); } - absl::StrAppend(&filename, ".pbtxt"); + absl::StrAppend(&filename, suffix); return filename; } -#if defined(TENSORFLOW_LITE_PROTOS) -Status WriteToFile(const string& filepath, - const ::tensorflow::protobuf::MessageLite& proto) { - string s; - if (!SerializeToStringDeterministic(proto, &s)) { - return errors::Internal("Failed to serialize proto to string."); - } - return WriteStringToFile(Env::Default(), filepath, s); -} -#else -Status WriteToFile(const string& filepath, - const ::tensorflow::protobuf::Message& proto) { - return WriteTextProto(Env::Default(), filepath, proto); -} -#endif +struct GraphDumperConfig { + mutex mu; -template <class T> -string WriteTextProtoToUniqueFile(Env* env, const string& name, - const char* proto_type, T& proto, - const string& dirname) { + // The dumper and suffix configured. + struct Config { + bool IsSet() const { return dumper != nullptr; } + std::function<Status(const Graph& graph, + const FunctionLibraryDefinition* flib_def, + WritableFile*)> + dumper = nullptr; + string suffix = ".pbtxt"; + } config TF_GUARDED_BY(mu); + + // Returns whether a custom dumper is set. + bool IsSet() TF_LOCKS_EXCLUDED(mu) { + mutex_lock lock(mu); + return config.IsSet(); + } +}; + +GraphDumperConfig& GetGraphDumperConfig() { + static GraphDumperConfig config; + return config; +} + +// WritableFile that simply prints to stderr. +class StderrWritableFile : public WritableFile { + public: + StderrWritableFile() {} + + Status Append(StringPiece data) override { + fprintf(stderr, "%.*s", static_cast<int>(data.size()), data.data()); + return Status::OK(); + } + + Status Close() override { return Status::OK(); } + + Status Flush() override { + fflush(stderr); + return Status::OK(); + } + + Status Name(StringPiece* result) const override { + *result = "stderr"; + return Status::OK(); + } + + Status Sync() override { return Status::OK(); } + + Status Tell(int64* position) override { + return errors::Unimplemented("Stream not seekable"); + } +}; + +Status CreateWritableFile(Env* env, const string& dirname, const string& name, + const string& suffix, string* filepath, + std::unique_ptr<WritableFile>* file) { string dir; if (!dirname.empty()) { dir = dirname; @@ -92,7 +132,7 @@ string WriteTextProtoToUniqueFile(Env* env, const string& name, << "Failed to dump " << name << " because dump location is not " << " specified through either TF_DUMP_GRAPH_PREFIX environment " << "variable or function argument."; - return "(TF_DUMP_GRAPH_PREFIX not specified)"; + return errors::InvalidArgument("TF_DUMP_GRAPH_PREFIX not specified"); } if (absl::EqualsIgnoreCase(dir, "sponge") || @@ -104,40 +144,96 @@ string WriteTextProtoToUniqueFile(Env* env, const string& name, } } - string filepath = "NULL"; + *filepath = "NULL"; if (dir == "-") { - LOG(INFO) << proto.DebugString(); - filepath = "LOG(INFO)"; - } else { - Status status = env->RecursivelyCreateDir(dir); - if (!status.ok()) { - LOG(WARNING) << "Failed to create " << dir << " for dumping " - << proto_type << ": " << status; - return "(unavailable)"; - } - filepath = io::JoinPath(dir, MakeUniqueFilename(name)); - status = WriteToFile(filepath, proto); - if (!status.ok()) { - LOG(WARNING) << "Failed to dump " << proto_type - << " to file: " << filepath << " : " << status; - return "(unavailable)"; - } + *file = std::make_unique<StderrWritableFile>(); + *filepath = "(stderr)"; + return Status::OK(); } - LOG(INFO) << "Dumped " << proto_type << " to " << filepath; - return filepath; + + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir)); + *filepath = io::JoinPath(dir, MakeUniqueFilename(name, suffix)); + return env->NewWritableFile(*filepath, file); +} + +template <class T> +Status WriteTextProtoToUniqueFile(T& proto, WritableFile* file) { + string s; +#if defined(TENSORFLOW_LITE_PROTOS) + if (!SerializeToStringDeterministic(proto, &s)) { + return errors::Internal("Failed to serialize proto to string."); + } +#else + if (!::tensorflow::protobuf::TextFormat::PrintToString(proto, &s)) { + return errors::FailedPrecondition("Unable to convert proto to text."); + } +#endif + TF_RETURN_IF_ERROR(file->Append(s)); + return file->Close(); } } // anonymous namespace +void SetGraphDumper( + std::function<Status(const Graph& graph, + const FunctionLibraryDefinition* flib_def, + WritableFile*)> + dumper, + string suffix) { + GraphDumperConfig& dumper_config = GetGraphDumperConfig(); + mutex_lock lock(dumper_config.mu); + dumper_config.config.dumper = dumper; + dumper_config.config.suffix = suffix; +} + string DumpGraphDefToFile(const string& name, GraphDef const& graph_def, const string& dirname) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", graph_def, - dirname); + string filepath; + std::unique_ptr<WritableFile> file; + Status status = CreateWritableFile(Env::Default(), dirname, name, ".pbtxt", + &filepath, &file); + if (!status.ok()) { + return StrCat("(failed to create writable file: ", status.ToString(), ")"); + } + + status = WriteTextProtoToUniqueFile(graph_def, file.get()); + if (!status.ok()) { + return StrCat("(failed to dump Graph to '", filepath, + "': ", status.ToString(), ")"); + } + LOG(INFO) << "Dumped Graph to " << filepath; + return filepath; } string DumpGraphToFile(const string& name, Graph const& graph, const FunctionLibraryDefinition* flib_def, const string& dirname) { + auto& dumper_config = GetGraphDumperConfig(); + if (dumper_config.IsSet()) { + GraphDumperConfig::Config config; + { + mutex_lock lock(dumper_config.mu); + config = dumper_config.config; + } + if (config.IsSet()) { + string filepath; + std::unique_ptr<WritableFile> file; + Status status = CreateWritableFile(Env::Default(), dirname, name, + config.suffix, &filepath, &file); + if (!status.ok()) { + return StrCat("(failed to create writable file: ", status.ToString(), + ")"); + } + status = config.dumper(graph, flib_def, file.get()); + if (!status.ok()) { + return StrCat("(failed to dump Graph to '", filepath, + "': ", status.ToString(), ")"); + } + LOG(INFO) << "Dumped Graph to " << filepath; + return filepath; + } + } + GraphDef graph_def; graph.ToGraphDef(&graph_def); if (flib_def) { @@ -148,8 +244,21 @@ string DumpGraphToFile(const string& name, Graph const& graph, string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef, const string& dirname) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef, - dirname); + string filepath; + std::unique_ptr<WritableFile> file; + Status status = CreateWritableFile(Env::Default(), dirname, name, ".pbtxt", + &filepath, &file); + if (!status.ok()) { + return StrCat("(failed to create writable file: ", status.ToString(), ")"); + } + + status = WriteTextProtoToUniqueFile(fdef, file.get()); + if (!status.ok()) { + return StrCat("(failed to dump FunctionDef to '", filepath, + "': ", status.ToString(), ")"); + } + LOG(INFO) << "Dumped FunctionDef to " << filepath; + return filepath; } } // namespace tensorflow diff --git a/tensorflow/core/util/dump_graph.h b/tensorflow/core/util/dump_graph.h index fe64d1766ea..e4428bd0206 100644 --- a/tensorflow/core/util/dump_graph.h +++ b/tensorflow/core/util/dump_graph.h @@ -50,6 +50,16 @@ string DumpGraphToFile(const string& name, Graph const& graph, string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef, const string& dirname = ""); +// Sets a custom Graph dumper. If set, this dumper will be used to dump graphs +// instead via DumpGraphToFile. As the custom dumper may not produce protobufs, +// allow specifying a file suffix/extension too. +void SetGraphDumper( + std::function<Status(const Graph& graph, + const FunctionLibraryDefinition* flib_def, + WritableFile*)> + dumper, + string suffix = ".pbtxt"); + } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_ diff --git a/tensorflow/core/util/dump_graph_test.cc b/tensorflow/core/util/dump_graph_test.cc index d01c1c5a029..c7e510e05c5 100644 --- a/tensorflow/core/util/dump_graph_test.cc +++ b/tensorflow/core/util/dump_graph_test.cc @@ -14,8 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/util/dump_graph.h" + #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" @@ -36,7 +38,7 @@ TEST(DumpGraph, DumpGraphToFileSuccess) { EXPECT_EQ(ret, io::JoinPath(testing::TmpDir(), "graph_1.pbtxt")); GraphDef gdef; - TF_CHECK_OK(ReadTextProto( + TF_ASSERT_OK(ReadTextProto( Env::Default(), io::JoinPath(testing::TmpDir(), "graph.pbtxt"), &gdef)); string read, written; gdef.AppendToString(&read); @@ -48,7 +50,7 @@ TEST(DumpGraph, DumpGraphToFileNoEnvPrefix) { Graph graph(OpRegistry::Global()); unsetenv("TF_DUMP_GRAPH_PREFIX"); string ret = DumpGraphToFile("graph", graph); - EXPECT_EQ(ret, "(TF_DUMP_GRAPH_PREFIX not specified)"); + EXPECT_TRUE(str_util::StrContains(ret, "TF_DUMP_GRAPH_PREFIX not specified")); } TEST(DumpGraph, DumpFunctionDefToFileSuccess) {