Add get_potentially_supported_ops to output TF Lite compatible ops.
PiperOrigin-RevId: 241439699
This commit is contained in:
parent
269e8a35bd
commit
4ee64e012a
25
tensorflow/lite/experimental/tensorboard/BUILD
Normal file
25
tensorflow/lite/experimental/tensorboard/BUILD
Normal file
@ -0,0 +1,25 @@
|
||||
# TFLite modules to support TensorBoard plugin.
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
py_library(
|
||||
name = "ops_util",
|
||||
srcs = ["ops_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/lite/toco/python:tensorflow_wrap_toco",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ops_util_test",
|
||||
srcs = ["ops_util_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ops_util",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
4
tensorflow/lite/experimental/tensorboard/README.md
Normal file
4
tensorflow/lite/experimental/tensorboard/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
This folder contains basic modules to support TFLite plugin for TensorBoard.
|
||||
|
||||
Warning: Everything in this directory is experimental and highly subject to
|
||||
changes.
|
49
tensorflow/lite/experimental/tensorboard/ops_util.py
Normal file
49
tensorflow/lite/experimental/tensorboard/ops_util.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops util to handle ops for Lite."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
from tensorflow.lite.toco.python import tensorflow_wrap_toco
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
class SupportedOp(collections.namedtuple("SupportedOp", ["op"])):
|
||||
"""Spec of supported ops.
|
||||
|
||||
Args:
|
||||
op: string of op name.
|
||||
"""
|
||||
|
||||
|
||||
@tf_export("lite.experimental.get_potentially_supported_ops")
|
||||
def get_potentially_supported_ops():
|
||||
"""Returns operations potentially supported by TensorFlow Lite.
|
||||
|
||||
The potentially support list contains a list of ops that are partially or
|
||||
fully supported, which is derived by simply scanning op names to check whether
|
||||
they can be handled without real conversion and specific parameters.
|
||||
|
||||
Given that some ops may be partially supported, the optimal way to determine
|
||||
if a model's operations are supported is by converting using the TensorFlow
|
||||
Lite converter.
|
||||
|
||||
Returns:
|
||||
A list of SupportedOp.
|
||||
"""
|
||||
ops = tensorflow_wrap_toco.TocoGetPotentiallySupportedOps()
|
||||
return [SupportedOp(o["op"]) for o in ops]
|
39
tensorflow/lite/experimental/tensorboard/ops_util_test.py
Normal file
39
tensorflow/lite/experimental/tensorboard/ops_util_test.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for backend."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.lite.experimental.tensorboard import ops_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class OpsUtilTest(test.TestCase):
|
||||
|
||||
def testGetPotentiallySupportedOps(self):
|
||||
ops = ops_util.get_potentially_supported_ops()
|
||||
# See GetTensorFlowNodeConverterMap() in
|
||||
# tensorflow/lite/toco/import_tensorflow.cc
|
||||
self.assertIsInstance(ops, list)
|
||||
# Test partial ops that surely exist in the list.
|
||||
self.assertIn(ops_util.SupportedOp("Add"), ops)
|
||||
self.assertIn(ops_util.SupportedOp("Log"), ops)
|
||||
self.assertIn(ops_util.SupportedOp("Sigmoid"), ops)
|
||||
self.assertIn(ops_util.SupportedOp("Softmax"), ops)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -73,6 +73,7 @@ py_library(
|
||||
":op_hint",
|
||||
":util",
|
||||
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
|
||||
"//tensorflow/lite/experimental/tensorboard:ops_util",
|
||||
"//tensorflow/lite/python/optimize:calibrator",
|
||||
"//tensorflow/python:graph_util",
|
||||
"//tensorflow/python/keras",
|
||||
|
@ -24,9 +24,11 @@ from six import PY3
|
||||
|
||||
from google.protobuf import text_format as _text_format
|
||||
from google.protobuf.message import DecodeError
|
||||
from tensorflow.core.framework import graph_pb2 as _graph_pb2
|
||||
from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn # pylint: disable=unused-import
|
||||
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell # pylint: disable=unused-import
|
||||
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell # pylint: disable=unused-import
|
||||
from tensorflow.lite.experimental.tensorboard.ops_util import get_potentially_supported_ops # pylint: disable=unused-import
|
||||
from tensorflow.lite.python import lite_constants as constants
|
||||
from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import
|
||||
@ -47,7 +49,6 @@ from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_te
|
||||
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
|
||||
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
|
||||
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
|
||||
from tensorflow.core.framework import graph_pb2 as _graph_pb2
|
||||
from tensorflow.python import keras as _keras
|
||||
from tensorflow.python.client import session as _session
|
||||
from tensorflow.python.eager import def_function as _def_function
|
||||
|
@ -1394,5 +1394,11 @@ class FromKerasFile(test_util.TensorFlowTestCase):
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
|
||||
class ImportOpsUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testGetPotentiallySupportedOps(self):
|
||||
self.assertIsNotNone(lite.get_potentially_supported_ops())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -2629,4 +2629,16 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
|
||||
}
|
||||
return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
|
||||
}
|
||||
|
||||
std::vector<std::string> GetPotentiallySupportedOps() {
|
||||
std::vector<std::string> supported_ops;
|
||||
const internal::ConverterMapType& converter_map =
|
||||
internal::GetTensorFlowNodeConverterMap();
|
||||
|
||||
for (const auto& item : converter_map) {
|
||||
supported_ops.push_back(item.first);
|
||||
}
|
||||
return supported_ops;
|
||||
}
|
||||
|
||||
} // namespace toco
|
||||
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
|
||||
namespace toco {
|
||||
|
||||
@ -34,14 +34,20 @@ struct TensorFlowImportFlags {
|
||||
bool import_all_ops_as_unsupported = false;
|
||||
};
|
||||
|
||||
// Converts TOCO model from TensorFlow GraphDef with given flags.
|
||||
std::unique_ptr<Model> ImportTensorFlowGraphDef(
|
||||
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
|
||||
const tensorflow::GraphDef& graph_def);
|
||||
|
||||
// Converts TOCO model from the file content of TensorFlow GraphDef with given
|
||||
// flags.
|
||||
std::unique_ptr<Model> ImportTensorFlowGraphDef(
|
||||
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
|
||||
const string& input_file_contents);
|
||||
|
||||
// Gets a list of supported ops by their names.
|
||||
std::vector<std::string> GetPotentiallySupportedOps();
|
||||
|
||||
} // namespace toco
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOCO_IMPORT_TENSORFLOW_H_
|
||||
|
@ -32,4 +32,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
PyObject* input_contents_txt_raw,
|
||||
bool extended_return = false);
|
||||
|
||||
// Returns a list of names of all ops potentially supported by tflite.
|
||||
PyObject* TocoGetPotentiallySupportedOps();
|
||||
|
||||
} // namespace toco
|
@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
#include "tensorflow/lite/toco/import_tensorflow.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/python/toco_python_api.h"
|
||||
#include "tensorflow/lite/toco/toco_flags.pb.h"
|
||||
@ -49,21 +51,32 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
bool error;
|
||||
std::string model_flags_proto_txt =
|
||||
ConvertArg(model_flags_proto_txt_raw, &error);
|
||||
if (error) return nullptr;
|
||||
if (error) {
|
||||
PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
|
||||
return nullptr;
|
||||
}
|
||||
std::string toco_flags_proto_txt =
|
||||
ConvertArg(toco_flags_proto_txt_raw, &error);
|
||||
if (error) return nullptr;
|
||||
if (error) {
|
||||
PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
|
||||
return nullptr;
|
||||
}
|
||||
std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
|
||||
if (error) return nullptr;
|
||||
if (error) {
|
||||
PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Use TOCO to produce new outputs.
|
||||
toco::ModelFlags model_flags;
|
||||
if (!model_flags.ParseFromString(model_flags_proto_txt)) {
|
||||
LOG(FATAL) << "Model proto failed to parse." << std::endl;
|
||||
PyErr_SetString(PyExc_ValueError, "Model proto failed to parse.");
|
||||
return nullptr;
|
||||
}
|
||||
toco::TocoFlags toco_flags;
|
||||
if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
|
||||
LOG(FATAL) << "Toco proto failed to parse." << std::endl;
|
||||
PyErr_SetString(PyExc_ValueError, "Toco proto failed to parse.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto& dump_options = *GraphVizDumpOptions::singleton();
|
||||
@ -100,4 +113,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
output_file_contents_txt.data(), output_file_contents_txt.size());
|
||||
}
|
||||
|
||||
PyObject* TocoGetPotentiallySupportedOps() {
|
||||
std::vector<std::string> supported_ops = toco::GetPotentiallySupportedOps();
|
||||
PyObject* list = PyList_New(supported_ops.size());
|
||||
for (size_t i = 0; i < supported_ops.size(); ++i) {
|
||||
const string& op = supported_ops[i];
|
||||
PyObject* op_dict = PyDict_New();
|
||||
PyDict_SetItemString(op_dict, "op", PyUnicode_FromString(op.c_str()));
|
||||
PyList_SetItem(list, i, op_dict);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
} // namespace toco
|
||||
|
@ -31,6 +31,9 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
|
||||
PyObject* input_contents_txt_raw,
|
||||
bool extended_return = false);
|
||||
|
||||
// Returns a list of names of all ops potentially supported by tflite.
|
||||
PyObject* TocoGetPotentiallySupportedOps();
|
||||
|
||||
} // namespace toco
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
|
||||
|
@ -125,12 +125,12 @@ def freeze_graph_with_def_protos(input_graph_def,
|
||||
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
|
||||
if (not input_saved_model_dir and
|
||||
not checkpoint_management.checkpoint_exists(input_checkpoint)):
|
||||
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
|
||||
return -1
|
||||
raise ValueError("Input checkpoint '" + input_checkpoint +
|
||||
"' doesn't exist!")
|
||||
|
||||
if not output_node_names:
|
||||
print("You need to supply the name of a node to --output_node_names.")
|
||||
return -1
|
||||
raise ValueError(
|
||||
"You need to supply the name of a node to --output_node_names.")
|
||||
|
||||
# Remove all the explicit device specifications for this node. This helps to
|
||||
# make the graph more portable.
|
||||
@ -193,14 +193,15 @@ def freeze_graph_with_def_protos(input_graph_def,
|
||||
# tensors. Partition variables are Identity tensors that cannot be
|
||||
# handled by Saver.
|
||||
if has_partition_var:
|
||||
print("Models containing partition variables cannot be converted "
|
||||
"from checkpoint files. Please pass in a SavedModel using "
|
||||
"the flag --input_saved_model_dir.")
|
||||
return -1
|
||||
raise ValueError(
|
||||
"Models containing partition variables cannot be converted "
|
||||
"from checkpoint files. Please pass in a SavedModel using "
|
||||
"the flag --input_saved_model_dir.")
|
||||
# Models that have been frozen previously do not contain Variables.
|
||||
elif _has_no_variables(sess):
|
||||
print("No variables were found in this model. It is likely the model "
|
||||
"was frozen previously. You cannot freeze a graph twice.")
|
||||
raise ValueError(
|
||||
"No variables were found in this model. It is likely the model "
|
||||
"was frozen previously. You cannot freeze a graph twice.")
|
||||
return 0
|
||||
else:
|
||||
raise e
|
||||
@ -242,8 +243,7 @@ def freeze_graph_with_def_protos(input_graph_def,
|
||||
def _parse_input_graph_proto(input_graph, input_binary):
|
||||
"""Parses input tensorflow graph into GraphDef proto."""
|
||||
if not gfile.Exists(input_graph):
|
||||
print("Input graph file '" + input_graph + "' does not exist!")
|
||||
return -1
|
||||
raise IOError("Input graph file '" + input_graph + "' does not exist!")
|
||||
input_graph_def = graph_pb2.GraphDef()
|
||||
mode = "rb" if input_binary else "r"
|
||||
with gfile.GFile(input_graph, mode) as f:
|
||||
@ -257,8 +257,7 @@ def _parse_input_graph_proto(input_graph, input_binary):
|
||||
def _parse_input_meta_graph_proto(input_graph, input_binary):
|
||||
"""Parses input tensorflow graph into MetaGraphDef proto."""
|
||||
if not gfile.Exists(input_graph):
|
||||
print("Input meta graph file '" + input_graph + "' does not exist!")
|
||||
return -1
|
||||
raise IOError("Input meta graph file '" + input_graph + "' does not exist!")
|
||||
input_meta_graph_def = MetaGraphDef()
|
||||
mode = "rb" if input_binary else "r"
|
||||
with gfile.GFile(input_graph, mode) as f:
|
||||
@ -273,8 +272,7 @@ def _parse_input_meta_graph_proto(input_graph, input_binary):
|
||||
def _parse_input_saver_proto(input_saver, input_binary):
|
||||
"""Parses input tensorflow Saver into SaverDef proto."""
|
||||
if not gfile.Exists(input_saver):
|
||||
print("Input saver file '" + input_saver + "' does not exist!")
|
||||
return -1
|
||||
raise IOError("Input saver file '" + input_saver + "' does not exist!")
|
||||
mode = "rb" if input_binary else "r"
|
||||
with gfile.GFile(input_saver, mode) as f:
|
||||
saver_def = saver_pb2.SaverDef()
|
||||
@ -369,9 +367,8 @@ def main(unused_args, flags):
|
||||
elif flags.checkpoint_version == 2:
|
||||
checkpoint_version = saver_pb2.SaverDef.V2
|
||||
else:
|
||||
print("Invalid checkpoint version (must be '1' or '2'): %d" %
|
||||
flags.checkpoint_version)
|
||||
return -1
|
||||
raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
|
||||
flags.checkpoint_version)
|
||||
freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
|
||||
flags.input_checkpoint, flags.output_node_names,
|
||||
flags.restore_op_name, flags.filename_tensor_name,
|
||||
@ -380,7 +377,9 @@ def main(unused_args, flags):
|
||||
flags.input_meta_graph, flags.input_saved_model_dir,
|
||||
flags.saved_model_tags, checkpoint_version)
|
||||
|
||||
|
||||
def run_main():
|
||||
"""Main function of freeze_graph."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
@ -487,5 +486,6 @@ def run_main():
|
||||
my_main = lambda unused_args: main(unused_args, flags)
|
||||
app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_main()
|
||||
|
@ -316,17 +316,17 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
||||
output_node_names = "save/restore_all"
|
||||
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
|
||||
|
||||
return_value = freeze_graph.freeze_graph_with_def_protos(
|
||||
input_graph_def=sess.graph_def,
|
||||
input_saver_def=None,
|
||||
input_checkpoint=checkpoint_path,
|
||||
output_node_names=output_node_names,
|
||||
restore_op_name="save/restore_all", # default value
|
||||
filename_tensor_name="save/Const:0", # default value
|
||||
output_graph=output_graph_path,
|
||||
clear_devices=False,
|
||||
initializer_nodes="")
|
||||
self.assertTrue(return_value, -1)
|
||||
with self.assertRaises(ValueError):
|
||||
freeze_graph.freeze_graph_with_def_protos(
|
||||
input_graph_def=sess.graph_def,
|
||||
input_saver_def=None,
|
||||
input_checkpoint=checkpoint_path,
|
||||
output_node_names=output_node_names,
|
||||
restore_op_name="save/restore_all", # default value
|
||||
filename_tensor_name="save/Const:0", # default value
|
||||
output_graph=output_graph_path,
|
||||
clear_devices=False,
|
||||
initializer_nodes="")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -8,4 +8,8 @@ tf_module {
|
||||
name: "convert_op_hints_to_stubs"
|
||||
argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'<function <lambda> instance>\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_potentially_supported_ops"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -8,4 +8,8 @@ tf_module {
|
||||
name: "convert_op_hints_to_stubs"
|
||||
argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'<function <lambda> instance>\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_potentially_supported_ops"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user