Support the dump_graphviz_dir flag

If this flag is specified, the dot files from the functions are written to the
directory before and after the transformations.

PiperOrigin-RevId: 277326605
Change-Id: I0463f03f657a267c997b1a3671a6ba130c9ec1c2
This commit is contained in:
Feng Liu 2019-10-29 11:25:24 -07:00 committed by TensorFlower Gardener
parent 01182a2f98
commit 2061ec8104
3 changed files with 68 additions and 17 deletions

View File

@ -22,11 +22,15 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc", "//tensorflow/core:protos_all_proto_cc",
"//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:model_flags_proto_cc",
"//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/toco:types_proto_cc",
"@llvm//:support",
"@local_config_mlir//:IR", "@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
"@local_config_mlir//:ViewOpGraph",
], ],
) )

View File

@ -17,8 +17,12 @@ limitations under the License.
#include <ostream> #include <ostream>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Transforms/ViewOpGraph.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
@ -34,6 +38,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace {
// Converts the toco::IODataType to tensorflow::DataType. Only contains the // Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in TFLite Python API. // conversion mapping for constants defined in TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
@ -78,9 +83,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
if (model_flags.change_concat_input_ranges()) { if (model_flags.change_concat_input_ranges()) {
LOG(WARNING) << "Ignored change_concat_input_ranges."; LOG(WARNING) << "Ignored change_concat_input_ranges.";
} }
if (toco_flags.dump_graphviz_dir().empty()) {
LOG(WARNING) << "Ignored dump_graphviz_dir.";
}
if (toco_flags.dump_graphviz_include_video()) { if (toco_flags.dump_graphviz_include_video()) {
LOG(WARNING) << "Ignored dump_graphviz_video."; LOG(WARNING) << "Ignored dump_graphviz_video.";
} }
@ -89,6 +91,24 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
} }
} }
// Dumps the op graph of the `module` to `filename` in DOT format.
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file in %s.", filename);
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
if (failed(pm.run(module))) {
return errors::Unknown("Failed to dump Op Graph from MLIR module.");
}
output->keep();
return Status::OK();
}
} // namespace
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags, const toco::TocoFlags& toco_flags,
const GraphDebugInfo& debug_info, const GraphDebugInfo& debug_info,
@ -175,6 +195,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context)); auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
module.get(),
// rename once we enable the new converter feature flag.
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
}
mlir::PassManager pm(module->getContext()); mlir::PassManager pm(module->getContext());
mlir::TFL::PassConfig pass_config(quant_specs); mlir::TFL::PassConfig pass_config(quant_specs);
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
@ -182,10 +209,19 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
return ConvertTFExecutorToTFLOrFlatbuffer( auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, /*emit_quant_adaptor_ops=*/false, emit_select_tf_ops, emit_custom_ops, /*emit_quant_adaptor_ops=*/false,
/*lower_tensor_list_ops=*/true, quant_specs, result, &pm); /*lower_tensor_list_ops=*/true, quant_specs, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.
module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
"/toco_AFTER_TRANSFORMATIONS.dot")));
}
return status;
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -456,7 +456,10 @@ class FromSessionTest(TestModels, parameterized.TestCase):
graphviz_output = converter.convert() graphviz_output = converter.convert()
self.assertTrue(graphviz_output) self.assertTrue(graphviz_output)
def testDumpGraphviz(self): @parameterized.named_parameters(
('EnableMlirConverter', True), # enable mlir
('DisableMlirConverter', False)) # disable mlir
def testDumpGraphviz(self, enable_mlir):
with ops.Graph().as_default(): with ops.Graph().as_default():
in_tensor = array_ops.placeholder( in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32) shape=[1, 16, 16, 3], dtype=dtypes.float32)
@ -466,6 +469,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
# Convert model and ensure model is not None. # Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor], converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor]) [out_tensor])
converter.experimental_new_converter = enable_mlir
graphviz_dir = self.get_temp_dir() graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir converter.dump_graphviz_dir = graphviz_dir
tflite_model = converter.convert() tflite_model = converter.convert()
@ -477,19 +481,26 @@ class FromSessionTest(TestModels, parameterized.TestCase):
num_items_graphviz = len(os.listdir(graphviz_dir)) num_items_graphviz = len(os.listdir(graphviz_dir))
self.assertTrue(num_items_graphviz) self.assertTrue(num_items_graphviz)
self.assertTrue(
os.path.exists(os.path.join(graphviz_dir, 'toco_AT_IMPORT.dot')))
self.assertTrue(
os.path.exists(
os.path.join(graphviz_dir, 'toco_AFTER_TRANSFORMATIONS.dot')))
# Convert model and ensure model is not None. # new converter doesn't support `dump_graphviz_video` flag
converter = lite.TFLiteConverter.from_session(sess, [in_tensor], if not enable_mlir:
[out_tensor]) # Convert model and ensure model is not None.
graphviz_dir = self.get_temp_dir() converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
converter.dump_graphviz_dir = graphviz_dir [out_tensor])
converter.dump_graphviz_video = True graphviz_dir = self.get_temp_dir()
tflite_model = converter.convert() converter.dump_graphviz_dir = graphviz_dir
self.assertTrue(tflite_model) converter.dump_graphviz_video = True
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Ensure graphviz folder has more data after using video flag. # Ensure graphviz folder has more data after using video flag.
num_items_graphviz_video = len(os.listdir(graphviz_dir)) num_items_graphviz_video = len(os.listdir(graphviz_dir))
self.assertTrue(num_items_graphviz_video > num_items_graphviz) self.assertGreater(num_items_graphviz_video, num_items_graphviz)
def testInferenceInputType(self): def testInferenceInputType(self):
with ops.Graph().as_default(): with ops.Graph().as_default():