Add get_potentially_supported_ops to output TF Lite compatible ops.

PiperOrigin-RevId: 241439699
This commit is contained in:
Tian Lin 2019-04-01 18:07:25 -07:00 committed by TensorFlower Gardener
parent 269e8a35bd
commit 4ee64e012a
16 changed files with 221 additions and 39 deletions

View File

@ -0,0 +1,25 @@
# TFLite modules to support TensorBoard plugin.
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "ops_util",
srcs = ["ops_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/lite/toco/python:tensorflow_wrap_toco",
"//tensorflow/python:util",
],
)
py_test(
name = "ops_util_test",
srcs = ["ops_util_test.py"],
srcs_version = "PY2AND3",
deps = [
":ops_util",
"//tensorflow/python:client_testlib",
],
)

View File

@ -0,0 +1,4 @@
This folder contains basic modules to support TFLite plugin for TensorBoard.
Warning: Everything in this directory is experimental and highly subject to
changes.

View File

@ -0,0 +1,49 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops util to handle ops for Lite."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.lite.toco.python import tensorflow_wrap_toco
from tensorflow.python.util.tf_export import tf_export
class SupportedOp(collections.namedtuple("SupportedOp", ["op"])):
"""Spec of supported ops.
Args:
op: string of op name.
"""
@tf_export("lite.experimental.get_potentially_supported_ops")
def get_potentially_supported_ops():
"""Returns operations potentially supported by TensorFlow Lite.
The potentially support list contains a list of ops that are partially or
fully supported, which is derived by simply scanning op names to check whether
they can be handled without real conversion and specific parameters.
Given that some ops may be partially supported, the optimal way to determine
if a model's operations are supported is by converting using the TensorFlow
Lite converter.
Returns:
A list of SupportedOp.
"""
ops = tensorflow_wrap_toco.TocoGetPotentiallySupportedOps()
return [SupportedOp(o["op"]) for o in ops]

View File

@ -0,0 +1,39 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for backend."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.lite.experimental.tensorboard import ops_util
from tensorflow.python.platform import test
class OpsUtilTest(test.TestCase):
def testGetPotentiallySupportedOps(self):
ops = ops_util.get_potentially_supported_ops()
# See GetTensorFlowNodeConverterMap() in
# tensorflow/lite/toco/import_tensorflow.cc
self.assertIsInstance(ops, list)
# Test partial ops that surely exist in the list.
self.assertIn(ops_util.SupportedOp("Add"), ops)
self.assertIn(ops_util.SupportedOp("Log"), ops)
self.assertIn(ops_util.SupportedOp("Sigmoid"), ops)
self.assertIn(ops_util.SupportedOp("Softmax"), ops)
if __name__ == "__main__":
test.main()

View File

@ -73,6 +73,7 @@ py_library(
":op_hint",
":util",
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
"//tensorflow/lite/experimental/tensorboard:ops_util",
"//tensorflow/lite/python/optimize:calibrator",
"//tensorflow/python:graph_util",
"//tensorflow/python/keras",

View File

@ -24,9 +24,11 @@ from six import PY3
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn # pylint: disable=unused-import
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell # pylint: disable=unused-import
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell # pylint: disable=unused-import
from tensorflow.lite.experimental.tensorboard.ops_util import get_potentially_supported_ops # pylint: disable=unused-import
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import
@ -47,7 +49,6 @@ from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_te
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.eager import def_function as _def_function

View File

@ -1394,5 +1394,11 @@ class FromKerasFile(test_util.TensorFlowTestCase):
interpreter.allocate_tensors()
class ImportOpsUtilTest(test_util.TensorFlowTestCase):
def testGetPotentiallySupportedOps(self):
self.assertIsNotNone(lite.get_potentially_supported_ops())
if __name__ == '__main__':
test.main()

View File

@ -2629,4 +2629,16 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
}
return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
}
std::vector<std::string> GetPotentiallySupportedOps() {
std::vector<std::string> supported_ops;
const internal::ConverterMapType& converter_map =
internal::GetTensorFlowNodeConverterMap();
for (const auto& item : converter_map) {
supported_ops.push_back(item.first);
}
return supported_ops;
}
} // namespace toco

View File

@ -17,9 +17,9 @@ limitations under the License.
#include <memory>
#include <string>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
namespace toco {
@ -34,14 +34,20 @@ struct TensorFlowImportFlags {
bool import_all_ops_as_unsupported = false;
};
// Converts TOCO model from TensorFlow GraphDef with given flags.
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const tensorflow::GraphDef& graph_def);
// Converts TOCO model from the file content of TensorFlow GraphDef with given
// flags.
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const string& input_file_contents);
// Gets a list of supported ops by their names.
std::vector<std::string> GetPotentiallySupportedOps();
} // namespace toco
#endif // TENSORFLOW_LITE_TOCO_IMPORT_TENSORFLOW_H_

View File

@ -32,4 +32,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* input_contents_txt_raw,
bool extended_return = false);
// Returns a list of names of all ops potentially supported by tflite.
PyObject* TocoGetPotentiallySupportedOps();
} // namespace toco

View File

@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <map>
#include <string>
#include <vector>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/python/toco_python_api.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
@ -49,21 +51,32 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
bool error;
std::string model_flags_proto_txt =
ConvertArg(model_flags_proto_txt_raw, &error);
if (error) return nullptr;
if (error) {
PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
return nullptr;
}
std::string toco_flags_proto_txt =
ConvertArg(toco_flags_proto_txt_raw, &error);
if (error) return nullptr;
if (error) {
PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
return nullptr;
}
std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
if (error) return nullptr;
if (error) {
PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
return nullptr;
}
// Use TOCO to produce new outputs.
toco::ModelFlags model_flags;
if (!model_flags.ParseFromString(model_flags_proto_txt)) {
LOG(FATAL) << "Model proto failed to parse." << std::endl;
PyErr_SetString(PyExc_ValueError, "Model proto failed to parse.");
return nullptr;
}
toco::TocoFlags toco_flags;
if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
LOG(FATAL) << "Toco proto failed to parse." << std::endl;
PyErr_SetString(PyExc_ValueError, "Toco proto failed to parse.");
return nullptr;
}
auto& dump_options = *GraphVizDumpOptions::singleton();
@ -100,4 +113,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
output_file_contents_txt.data(), output_file_contents_txt.size());
}
PyObject* TocoGetPotentiallySupportedOps() {
std::vector<std::string> supported_ops = toco::GetPotentiallySupportedOps();
PyObject* list = PyList_New(supported_ops.size());
for (size_t i = 0; i < supported_ops.size(); ++i) {
const string& op = supported_ops[i];
PyObject* op_dict = PyDict_New();
PyDict_SetItemString(op_dict, "op", PyUnicode_FromString(op.c_str()));
PyList_SetItem(list, i, op_dict);
}
return list;
}
} // namespace toco

View File

@ -31,6 +31,9 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* input_contents_txt_raw,
bool extended_return = false);
// Returns a list of names of all ops potentially supported by tflite.
PyObject* TocoGetPotentiallySupportedOps();
} // namespace toco
#endif // TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_

View File

@ -125,12 +125,12 @@ def freeze_graph_with_def_protos(input_graph_def,
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if (not input_saved_model_dir and
not checkpoint_management.checkpoint_exists(input_checkpoint)):
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
return -1
raise ValueError("Input checkpoint '" + input_checkpoint +
"' doesn't exist!")
if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1
raise ValueError(
"You need to supply the name of a node to --output_node_names.")
# Remove all the explicit device specifications for this node. This helps to
# make the graph more portable.
@ -193,14 +193,15 @@ def freeze_graph_with_def_protos(input_graph_def,
# tensors. Partition variables are Identity tensors that cannot be
# handled by Saver.
if has_partition_var:
print("Models containing partition variables cannot be converted "
"from checkpoint files. Please pass in a SavedModel using "
"the flag --input_saved_model_dir.")
return -1
raise ValueError(
"Models containing partition variables cannot be converted "
"from checkpoint files. Please pass in a SavedModel using "
"the flag --input_saved_model_dir.")
# Models that have been frozen previously do not contain Variables.
elif _has_no_variables(sess):
print("No variables were found in this model. It is likely the model "
"was frozen previously. You cannot freeze a graph twice.")
raise ValueError(
"No variables were found in this model. It is likely the model "
"was frozen previously. You cannot freeze a graph twice.")
return 0
else:
raise e
@ -242,8 +243,7 @@ def freeze_graph_with_def_protos(input_graph_def,
def _parse_input_graph_proto(input_graph, input_binary):
"""Parses input tensorflow graph into GraphDef proto."""
if not gfile.Exists(input_graph):
print("Input graph file '" + input_graph + "' does not exist!")
return -1
raise IOError("Input graph file '" + input_graph + "' does not exist!")
input_graph_def = graph_pb2.GraphDef()
mode = "rb" if input_binary else "r"
with gfile.GFile(input_graph, mode) as f:
@ -257,8 +257,7 @@ def _parse_input_graph_proto(input_graph, input_binary):
def _parse_input_meta_graph_proto(input_graph, input_binary):
"""Parses input tensorflow graph into MetaGraphDef proto."""
if not gfile.Exists(input_graph):
print("Input meta graph file '" + input_graph + "' does not exist!")
return -1
raise IOError("Input meta graph file '" + input_graph + "' does not exist!")
input_meta_graph_def = MetaGraphDef()
mode = "rb" if input_binary else "r"
with gfile.GFile(input_graph, mode) as f:
@ -273,8 +272,7 @@ def _parse_input_meta_graph_proto(input_graph, input_binary):
def _parse_input_saver_proto(input_saver, input_binary):
"""Parses input tensorflow Saver into SaverDef proto."""
if not gfile.Exists(input_saver):
print("Input saver file '" + input_saver + "' does not exist!")
return -1
raise IOError("Input saver file '" + input_saver + "' does not exist!")
mode = "rb" if input_binary else "r"
with gfile.GFile(input_saver, mode) as f:
saver_def = saver_pb2.SaverDef()
@ -369,9 +367,8 @@ def main(unused_args, flags):
elif flags.checkpoint_version == 2:
checkpoint_version = saver_pb2.SaverDef.V2
else:
print("Invalid checkpoint version (must be '1' or '2'): %d" %
flags.checkpoint_version)
return -1
raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
flags.checkpoint_version)
freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
flags.input_checkpoint, flags.output_node_names,
flags.restore_op_name, flags.filename_tensor_name,
@ -380,7 +377,9 @@ def main(unused_args, flags):
flags.input_meta_graph, flags.input_saved_model_dir,
flags.saved_model_tags, checkpoint_version)
def run_main():
"""Main function of freeze_graph."""
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
@ -487,5 +486,6 @@ def run_main():
my_main = lambda unused_args: main(unused_args, flags)
app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
if __name__ == '__main__':
if __name__ == "__main__":
run_main()

View File

@ -316,17 +316,17 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output_node_names = "save/restore_all"
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
return_value = freeze_graph.freeze_graph_with_def_protos(
input_graph_def=sess.graph_def,
input_saver_def=None,
input_checkpoint=checkpoint_path,
output_node_names=output_node_names,
restore_op_name="save/restore_all", # default value
filename_tensor_name="save/Const:0", # default value
output_graph=output_graph_path,
clear_devices=False,
initializer_nodes="")
self.assertTrue(return_value, -1)
with self.assertRaises(ValueError):
freeze_graph.freeze_graph_with_def_protos(
input_graph_def=sess.graph_def,
input_saver_def=None,
input_checkpoint=checkpoint_path,
output_node_names=output_node_names,
restore_op_name="save/restore_all", # default value
filename_tensor_name="save/Const:0", # default value
output_graph=output_graph_path,
clear_devices=False,
initializer_nodes="")
if __name__ == "__main__":

View File

@ -8,4 +8,8 @@ tf_module {
name: "convert_op_hints_to_stubs"
argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'<function <lambda> instance>\'], "
}
member_method {
name: "get_potentially_supported_ops"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -8,4 +8,8 @@ tf_module {
name: "convert_op_hints_to_stubs"
argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'<function <lambda> instance>\'], "
}
member_method {
name: "get_potentially_supported_ops"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
}