Support the dump_graphviz_dir
flag
If this flag is specified, the dot files from the functions are written to the directory before and after the transformations. PiperOrigin-RevId: 277326605 Change-Id: I0463f03f657a267c997b1a3671a6ba130c9ec1c2
This commit is contained in:
parent
01182a2f98
commit
2061ec8104
@ -22,11 +22,15 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/lite/toco:model_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:toco_flags_proto_cc",
|
||||
"//tensorflow/lite/toco:types_proto_cc",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:ViewOpGraph",
|
||||
],
|
||||
)
|
||||
|
@ -17,8 +17,12 @@ limitations under the License.
|
||||
|
||||
#include <ostream>
|
||||
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/ViewOpGraph.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
|
||||
@ -34,6 +38,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
||||
// conversion mapping for constants defined in TFLite Python API.
|
||||
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
@ -78,9 +83,6 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
|
||||
if (model_flags.change_concat_input_ranges()) {
|
||||
LOG(WARNING) << "Ignored change_concat_input_ranges.";
|
||||
}
|
||||
if (toco_flags.dump_graphviz_dir().empty()) {
|
||||
LOG(WARNING) << "Ignored dump_graphviz_dir.";
|
||||
}
|
||||
if (toco_flags.dump_graphviz_include_video()) {
|
||||
LOG(WARNING) << "Ignored dump_graphviz_video.";
|
||||
}
|
||||
@ -89,6 +91,24 @@ void WarningUnusedFlags(const toco::ModelFlags& model_flags,
|
||||
}
|
||||
}
|
||||
|
||||
// Dumps the op graph of the `module` to `filename` in DOT format.
|
||||
Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
std::string error_message;
|
||||
auto output = mlir::openOutputFile(filename, &error_message);
|
||||
if (!error_message.empty()) {
|
||||
return errors::InvalidArgument("Failed to open file in %s.", filename);
|
||||
}
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
|
||||
if (failed(pm.run(module))) {
|
||||
return errors::Unknown("Failed to dump Op Graph from MLIR module.");
|
||||
}
|
||||
output->keep();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
const toco::TocoFlags& toco_flags,
|
||||
const GraphDebugInfo& debug_info,
|
||||
@ -175,6 +195,13 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
|
||||
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
module.get(),
|
||||
// rename once we enable the new converter feature flag.
|
||||
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
|
||||
}
|
||||
|
||||
mlir::PassManager pm(module->getContext());
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
@ -182,10 +209,19 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
|
||||
|
||||
return ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, /*emit_quant_adaptor_ops=*/false,
|
||||
/*lower_tensor_list_ops=*/true, quant_specs, result, &pm);
|
||||
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
// rename once we enable the new converter feature flag.
|
||||
module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
|
||||
"/toco_AFTER_TRANSFORMATIONS.dot")));
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -456,7 +456,10 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||
graphviz_output = converter.convert()
|
||||
self.assertTrue(graphviz_output)
|
||||
|
||||
def testDumpGraphviz(self):
|
||||
@parameterized.named_parameters(
|
||||
('EnableMlirConverter', True), # enable mlir
|
||||
('DisableMlirConverter', False)) # disable mlir
|
||||
def testDumpGraphviz(self, enable_mlir):
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
||||
@ -466,6 +469,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||
# Convert model and ensure model is not None.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
converter.experimental_new_converter = enable_mlir
|
||||
graphviz_dir = self.get_temp_dir()
|
||||
converter.dump_graphviz_dir = graphviz_dir
|
||||
tflite_model = converter.convert()
|
||||
@ -477,19 +481,26 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||
|
||||
num_items_graphviz = len(os.listdir(graphviz_dir))
|
||||
self.assertTrue(num_items_graphviz)
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(graphviz_dir, 'toco_AT_IMPORT.dot')))
|
||||
self.assertTrue(
|
||||
os.path.exists(
|
||||
os.path.join(graphviz_dir, 'toco_AFTER_TRANSFORMATIONS.dot')))
|
||||
|
||||
# Convert model and ensure model is not None.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
graphviz_dir = self.get_temp_dir()
|
||||
converter.dump_graphviz_dir = graphviz_dir
|
||||
converter.dump_graphviz_video = True
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
# new converter doesn't support `dump_graphviz_video` flag
|
||||
if not enable_mlir:
|
||||
# Convert model and ensure model is not None.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
graphviz_dir = self.get_temp_dir()
|
||||
converter.dump_graphviz_dir = graphviz_dir
|
||||
converter.dump_graphviz_video = True
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Ensure graphviz folder has more data after using video flag.
|
||||
num_items_graphviz_video = len(os.listdir(graphviz_dir))
|
||||
self.assertTrue(num_items_graphviz_video > num_items_graphviz)
|
||||
# Ensure graphviz folder has more data after using video flag.
|
||||
num_items_graphviz_video = len(os.listdir(graphviz_dir))
|
||||
self.assertGreater(num_items_graphviz_video, num_items_graphviz)
|
||||
|
||||
def testInferenceInputType(self):
|
||||
with ops.Graph().as_default():
|
||||
|
Loading…
x
Reference in New Issue
Block a user