Automated rollback of commit 4ee64e012a

PiperOrigin-RevId: 244798543
This commit is contained in:
Tian Lin 2019-04-22 22:57:37 -07:00 committed by TensorFlower Gardener
parent 354b95f958
commit f0fba04cd2
18 changed files with 275 additions and 50 deletions

View File

@ -0,0 +1,25 @@
# TFLite modules to support TensorBoard plugin.
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "ops_util",
srcs = ["ops_util.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/lite/python:wrap_toco",
"//tensorflow/python:util",
],
)
py_test(
name = "ops_util_test",
srcs = ["ops_util_test.py"],
srcs_version = "PY2AND3",
deps = [
":ops_util",
"//tensorflow/python:client_testlib",
],
)

View File

@ -0,0 +1,4 @@
This folder contains basic modules to support TFLite plugin for TensorBoard.
Warning: Everything in this directory is experimental and highly subject to
changes.

View File

@ -0,0 +1,50 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops util to handle ops for Lite."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.lite.python import wrap_toco
from tensorflow.python.util.tf_export import tf_export
class SupportedOp(collections.namedtuple("SupportedOp", ["op"])):
"""Spec of supported ops.
Args:
op: string of op name.
"""
@tf_export(v1=["lite.experimental.get_potentially_supported_ops"])
def get_potentially_supported_ops():
"""Returns operations potentially supported by TensorFlow Lite.
The potentially support list contains a list of ops that are partially or
fully supported, which is derived by simply scanning op names to check whether
they can be handled without real conversion and specific parameters.
Given that some ops may be partially supported, the optimal way to determine
if a model's operations are supported is by converting using the TensorFlow
Lite converter.
Returns:
A list of SupportedOp.
"""
ops = wrap_toco.wrapped_get_potentially_supported_ops()
return [SupportedOp(o["op"]) for o in ops]

View File

@ -0,0 +1,39 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for backend."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.lite.experimental.tensorboard import ops_util
from tensorflow.python.platform import test
class OpsUtilTest(test.TestCase):
def testGetPotentiallySupportedOps(self):
ops = ops_util.get_potentially_supported_ops()
# See GetTensorFlowNodeConverterMap() in
# tensorflow/lite/toco/import_tensorflow.cc
self.assertIsInstance(ops, list)
# Test partial ops that surely exist in the list.
self.assertIn(ops_util.SupportedOp("Add"), ops)
self.assertIn(ops_util.SupportedOp("Log"), ops)
self.assertIn(ops_util.SupportedOp("Sigmoid"), ops)
self.assertIn(ops_util.SupportedOp("Softmax"), ops)
if __name__ == "__main__":
test.main()

View File

@ -73,6 +73,7 @@ py_library(
":op_hint",
":util",
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
"//tensorflow/lite/experimental/tensorboard:ops_util",
"//tensorflow/lite/python/optimize:calibrator",
"//tensorflow/python:graph_util",
"//tensorflow/python/keras",
@ -155,6 +156,17 @@ py_test(
],
)
py_library(
name = "wrap_toco",
srcs = [
"wrap_toco.py",
],
deps = [
"//tensorflow/lite/toco/python:tensorflow_wrap_toco",
"//tensorflow/python:util",
],
)
py_library(
name = "lite_constants",
srcs = ["lite_constants.py"],
@ -173,9 +185,9 @@ py_library(
deps = [
":lite_constants",
":util",
":wrap_toco",
"//tensorflow/lite/toco:model_flags_proto_py",
"//tensorflow/lite/toco:toco_flags_proto_py",
"//tensorflow/lite/toco/python:tensorflow_wrap_toco",
"//tensorflow/lite/toco/python:toco_from_protos",
"//tensorflow/python:dtypes",
"//tensorflow/python:platform",

View File

@ -27,21 +27,14 @@ import tempfile as _tempfile
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
from tensorflow.lite.python import wrap_toco
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export as _tf_export
# Lazy load since some of the performance benchmark skylark rules
# break dependencies.
_toco_python = LazyLoader(
"tensorflow_wrap_toco", globals(),
"tensorflow.lite.toco.python."
"tensorflow_wrap_toco")
del LazyLoader
# Find the toco_from_protos binary using the resource loader if using from
# bazel, otherwise we are in a pip where console_scripts already has
@ -119,8 +112,8 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
# switch this on.
if not _toco_from_proto_bin:
try:
model_str = _toco_python.TocoConvert(model_flags_str, toco_flags_str,
input_data_str)
model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
toco_flags_str, input_data_str)
return model_str
except Exception as e:
raise ConverterError("TOCO failed: %s" % e)

View File

@ -24,9 +24,11 @@ from six import PY3
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn # pylint: disable=unused-import
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell # pylint: disable=unused-import
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell # pylint: disable=unused-import
from tensorflow.lite.experimental.tensorboard.ops_util import get_potentially_supported_ops # pylint: disable=unused-import
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import
@ -47,7 +49,6 @@ from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_te
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.eager import context

View File

@ -1568,5 +1568,11 @@ class FromKerasFile(test_util.TensorFlowTestCase):
self.assertTrue(tflite_model)
class ImportOpsUtilTest(test_util.TensorFlowTestCase):
def testGetPotentiallySupportedOps(self):
self.assertIsNotNone(lite.get_potentially_supported_ops())
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,40 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wraps toco interface with python lazy loader."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.util.lazy_loader import LazyLoader
# TODO(b/131123224): Lazy load since some of the performance benchmark skylark
# rules and monolithic build break dependencies.
_toco_python = LazyLoader(
"tensorflow_wrap_toco", globals(),
"tensorflow.lite.toco.python."
"tensorflow_wrap_toco")
del LazyLoader
def wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str):
"""Wraps TocoConvert with lazy loader."""
return _toco_python.TocoConvert(model_flags_str, toco_flags_str,
input_data_str)
def wrapped_get_potentially_supported_ops():
"""Wraps TocoGetPotentiallySupportedOps with lazy loader."""
return _toco_python.TocoGetPotentiallySupportedOps()

View File

@ -2633,4 +2633,16 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
}
return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
}
std::vector<std::string> GetPotentiallySupportedOps() {
std::vector<std::string> supported_ops;
const internal::ConverterMapType& converter_map =
internal::GetTensorFlowNodeConverterMap();
for (const auto& item : converter_map) {
supported_ops.push_back(item.first);
}
return supported_ops;
}
} // namespace toco

View File

@ -17,9 +17,9 @@ limitations under the License.
#include <memory>
#include <string>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
namespace toco {
@ -34,14 +34,20 @@ struct TensorFlowImportFlags {
bool import_all_ops_as_unsupported = false;
};
// Converts TOCO model from TensorFlow GraphDef with given flags.
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const tensorflow::GraphDef& graph_def);
// Converts TOCO model from the file content of TensorFlow GraphDef with given
// flags.
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const string& input_file_contents);
// Gets a list of supported ops by their names.
std::vector<std::string> GetPotentiallySupportedOps();
} // namespace toco
#endif // TENSORFLOW_LITE_TOCO_IMPORT_TENSORFLOW_H_

View File

@ -32,4 +32,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* input_contents_txt_raw,
bool extended_return = false);
// Returns a list of names of all ops potentially supported by tflite.
PyObject* TocoGetPotentiallySupportedOps();
} // namespace toco

View File

@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <map>
#include <string>
#include <vector>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/python/toco_python_api.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
@ -49,21 +51,32 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
bool error;
std::string model_flags_proto_txt =
ConvertArg(model_flags_proto_txt_raw, &error);
if (error) return nullptr;
if (error) {
PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
return nullptr;
}
std::string toco_flags_proto_txt =
ConvertArg(toco_flags_proto_txt_raw, &error);
if (error) return nullptr;
if (error) {
PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
return nullptr;
}
std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
if (error) return nullptr;
if (error) {
PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
return nullptr;
}
// Use TOCO to produce new outputs.
toco::ModelFlags model_flags;
if (!model_flags.ParseFromString(model_flags_proto_txt)) {
LOG(FATAL) << "Model proto failed to parse." << std::endl;
PyErr_SetString(PyExc_ValueError, "Model proto failed to parse.");
return nullptr;
}
toco::TocoFlags toco_flags;
if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
LOG(FATAL) << "Toco proto failed to parse." << std::endl;
PyErr_SetString(PyExc_ValueError, "Toco proto failed to parse.");
return nullptr;
}
auto& dump_options = *GraphVizDumpOptions::singleton();
@ -100,4 +113,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
output_file_contents_txt.data(), output_file_contents_txt.size());
}
PyObject* TocoGetPotentiallySupportedOps() {
std::vector<std::string> supported_ops = toco::GetPotentiallySupportedOps();
PyObject* list = PyList_New(supported_ops.size());
for (size_t i = 0; i < supported_ops.size(); ++i) {
const string& op = supported_ops[i];
PyObject* op_dict = PyDict_New();
PyDict_SetItemString(op_dict, "op", PyUnicode_FromString(op.c_str()));
PyList_SetItem(list, i, op_dict);
}
return list;
}
} // namespace toco

View File

@ -31,6 +31,9 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* input_contents_txt_raw,
bool extended_return = false);
// Returns a list of names of all ops potentially supported by tflite.
PyObject* TocoGetPotentiallySupportedOps();
} // namespace toco
#endif // TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_

View File

@ -125,12 +125,12 @@ def freeze_graph_with_def_protos(input_graph_def,
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if (not input_saved_model_dir and
not checkpoint_management.checkpoint_exists(input_checkpoint)):
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
return -1
raise ValueError("Input checkpoint '" + input_checkpoint +
"' doesn't exist!")
if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1
raise ValueError(
"You need to supply the name of a node to --output_node_names.")
# Remove all the explicit device specifications for this node. This helps to
# make the graph more portable.
@ -193,14 +193,15 @@ def freeze_graph_with_def_protos(input_graph_def,
# tensors. Partition variables are Identity tensors that cannot be
# handled by Saver.
if has_partition_var:
print("Models containing partition variables cannot be converted "
"from checkpoint files. Please pass in a SavedModel using "
"the flag --input_saved_model_dir.")
return -1
raise ValueError(
"Models containing partition variables cannot be converted "
"from checkpoint files. Please pass in a SavedModel using "
"the flag --input_saved_model_dir.")
# Models that have been frozen previously do not contain Variables.
elif _has_no_variables(sess):
print("No variables were found in this model. It is likely the model "
"was frozen previously. You cannot freeze a graph twice.")
raise ValueError(
"No variables were found in this model. It is likely the model "
"was frozen previously. You cannot freeze a graph twice.")
return 0
else:
raise e
@ -242,8 +243,7 @@ def freeze_graph_with_def_protos(input_graph_def,
def _parse_input_graph_proto(input_graph, input_binary):
"""Parses input tensorflow graph into GraphDef proto."""
if not gfile.Exists(input_graph):
print("Input graph file '" + input_graph + "' does not exist!")
return -1
raise IOError("Input graph file '" + input_graph + "' does not exist!")
input_graph_def = graph_pb2.GraphDef()
mode = "rb" if input_binary else "r"
with gfile.GFile(input_graph, mode) as f:
@ -257,8 +257,7 @@ def _parse_input_graph_proto(input_graph, input_binary):
def _parse_input_meta_graph_proto(input_graph, input_binary):
"""Parses input tensorflow graph into MetaGraphDef proto."""
if not gfile.Exists(input_graph):
print("Input meta graph file '" + input_graph + "' does not exist!")
return -1
raise IOError("Input meta graph file '" + input_graph + "' does not exist!")
input_meta_graph_def = MetaGraphDef()
mode = "rb" if input_binary else "r"
with gfile.GFile(input_graph, mode) as f:
@ -273,8 +272,7 @@ def _parse_input_meta_graph_proto(input_graph, input_binary):
def _parse_input_saver_proto(input_saver, input_binary):
"""Parses input tensorflow Saver into SaverDef proto."""
if not gfile.Exists(input_saver):
print("Input saver file '" + input_saver + "' does not exist!")
return -1
raise IOError("Input saver file '" + input_saver + "' does not exist!")
mode = "rb" if input_binary else "r"
with gfile.GFile(input_saver, mode) as f:
saver_def = saver_pb2.SaverDef()
@ -369,9 +367,8 @@ def main(unused_args, flags):
elif flags.checkpoint_version == 2:
checkpoint_version = saver_pb2.SaverDef.V2
else:
print("Invalid checkpoint version (must be '1' or '2'): %d" %
flags.checkpoint_version)
return -1
raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
flags.checkpoint_version)
freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
flags.input_checkpoint, flags.output_node_names,
flags.restore_op_name, flags.filename_tensor_name,
@ -380,7 +377,9 @@ def main(unused_args, flags):
flags.input_meta_graph, flags.input_saved_model_dir,
flags.saved_model_tags, checkpoint_version)
def run_main():
"""Main function of freeze_graph."""
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
@ -487,5 +486,6 @@ def run_main():
my_main = lambda unused_args: main(unused_args, flags)
app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
if __name__ == '__main__':
if __name__ == "__main__":
run_main()

View File

@ -316,17 +316,17 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output_node_names = "save/restore_all"
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
return_value = freeze_graph.freeze_graph_with_def_protos(
input_graph_def=sess.graph_def,
input_saver_def=None,
input_checkpoint=checkpoint_path,
output_node_names=output_node_names,
restore_op_name="save/restore_all", # default value
filename_tensor_name="save/Const:0", # default value
output_graph=output_graph_path,
clear_devices=False,
initializer_nodes="")
self.assertTrue(return_value, -1)
with self.assertRaises(ValueError):
freeze_graph.freeze_graph_with_def_protos(
input_graph_def=sess.graph_def,
input_saver_def=None,
input_checkpoint=checkpoint_path,
output_node_names=output_node_names,
restore_op_name="save/restore_all", # default value
filename_tensor_name="save/Const:0", # default value
output_graph=output_graph_path,
clear_devices=False,
initializer_nodes="")
if __name__ == "__main__":

View File

@ -8,4 +8,8 @@ tf_module {
name: "convert_op_hints_to_stubs"
argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'<function <lambda> instance>\'], "
}
member_method {
name: "get_potentially_supported_ops"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -605,6 +605,8 @@ renames = {
'tf.compat.v1.lite.constants.TFLITE',
'tf.lite.experimental.convert_op_hints_to_stubs':
'tf.compat.v1.lite.experimental.convert_op_hints_to_stubs',
'tf.lite.experimental.get_potentially_supported_ops':
'tf.compat.v1.lite.experimental.get_potentially_supported_ops',
'tf.lite.experimental.nn.TFLiteLSTMCell':
'tf.compat.v1.lite.experimental.nn.TFLiteLSTMCell',
'tf.lite.experimental.nn.TfLiteRNNCell':