Support the dump_graphviz_dir flag

If this flag is specified, DOT files of the functions' op graphs are written to
that directory before and after the transformations.

PiperOrigin-RevId: 277326605
Change-Id: I0463f03f657a267c997b1a3671a6ba130c9ec1c2
Authored by Feng Liu on 2019-10-29 11:25:24 -07:00; committed by TensorFlower Gardener
parent 01182a2f98
commit 2061ec8104
3 changed files with 68 additions and 17 deletions
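To illustrate the user-visible behavior, here is a minimal Python sketch based on the unit test added in this change; the tensorflow.compat.v1 import style and the /tmp/graphviz output path are assumptions for illustration, not part of the commit.

import tensorflow.compat.v1 as tf  # assumed TF 1.x API surface with the MLIR converter available

with tf.Graph().as_default():
  in_tensor = tf.placeholder(shape=[1, 16, 16, 3], dtype=tf.float32)
  out_tensor = in_tensor + in_tensor
  with tf.Session() as sess:
    converter = tf.lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor])
    converter.experimental_new_converter = True    # route through the MLIR-based converter
    converter.dump_graphviz_dir = '/tmp/graphviz'  # hypothetical output directory
    tflite_model = converter.convert()

# /tmp/graphviz now contains one DOT file per dump point:
#   toco_AT_IMPORT.dot              - op graph right after GraphDef import
#   toco_AFTER_TRANSFORMATIONS.dot  - op graph after the TF-to-TFLite passes

As the updated test notes, the MLIR-based converter still ignores dump_graphviz_video; only the two per-stage DOT files above are produced.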


@@ -22,11 +22,15 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/lite/toco:model_flags_proto_cc",
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:Support",
"@local_config_mlir//:ViewOpGraph",
],
)


@@ -17,8 +17,12 @@ limitations under the License.
#include <ostream>
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Transforms/ViewOpGraph.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
@@ -34,6 +38,7 @@ limitations under the License.
namespace tensorflow {
namespace {
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in TFLite Python API.
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
@@ -78,9 +83,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
if (model_flags.change_concat_input_ranges()) {
LOG(WARNING) << "Ignored change_concat_input_ranges.";
}
if (toco_flags.dump_graphviz_dir().empty()) {
LOG(WARNING) << "Ignored dump_graphviz_dir.";
}
if (toco_flags.dump_graphviz_include_video()) {
LOG(WARNING) << "Ignored dump_graphviz_video.";
}
@@ -89,6 +91,24 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
}
}
// Dumps the op graph of the `module` to `filename` in DOT format.
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file ", filename, ": ", error_message);
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
if (failed(pm.run(module))) {
return errors::Unknown("Failed to dump Op Graph from MLIR module.");
}
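// Keep the file: llvm::ToolOutputFile deletes its output on destruction unless keep() is called.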
output->keep();
return Status::OK();
}
} // namespace
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
const GraphDebugInfo& debug_info,
@@ -175,6 +195,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
module.get(),
// rename once we enable the new converter feature flag.
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
}
mlir::PassManager pm(module->getContext());
mlir::TFL::PassConfig pass_config(quant_specs);
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
@@ -182,10 +209,19 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
return ConvertTFExecutorToTFLOrFlatbuffer(
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, /*emit_quant_adaptor_ops=*/false,
/*lower_tensor_list_ops=*/true, quant_specs, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.
module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
"/toco_AFTER_TRANSFORMATIONS.dot")));
}
return status;
}
} // namespace tensorflow


@@ -456,7 +456,10 @@ class FromSessionTest(TestModels, parameterized.TestCase):
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
def testDumpGraphviz(self):
@parameterized.named_parameters(
('EnableMlirConverter', True), # enable mlir
('DisableMlirConverter', False)) # disable mlir
def testDumpGraphviz(self, enable_mlir):
with ops.Graph().as_default():
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
@@ -466,6 +469,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.experimental_new_converter = enable_mlir
graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir
tflite_model = converter.convert()
@@ -477,19 +481,26 @@ class FromSessionTest(TestModels, parameterized.TestCase):
num_items_graphviz = len(os.listdir(graphviz_dir))
self.assertTrue(num_items_graphviz)
self.assertTrue(
os.path.exists(os.path.join(graphviz_dir, 'toco_AT_IMPORT.dot')))
self.assertTrue(
os.path.exists(
os.path.join(graphviz_dir, 'toco_AFTER_TRANSFORMATIONS.dot')))
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir
converter.dump_graphviz_video = True
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# The new converter doesn't support the `dump_graphviz_video` flag.
if not enable_mlir:
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir
converter.dump_graphviz_video = True
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Ensure graphviz folder has more data after using video flag.
num_items_graphviz_video = len(os.listdir(graphviz_dir))
self.assertTrue(num_items_graphviz_video > num_items_graphviz)
# Ensure graphviz folder has more data after using video flag.
num_items_graphviz_video = len(os.listdir(graphviz_dir))
self.assertGreater(num_items_graphviz_video, num_items_graphviz)
def testInferenceInputType(self):
with ops.Graph().as_default():