Add get_potentially_supported_ops to output TF Lite compatible ops.
PiperOrigin-RevId: 241439699
commit 4ee64e012a (parent 269e8a35bd)
tensorflow/lite/experimental/tensorboard/BUILD (new file, 25 lines)
@@ -0,0 +1,25 @@
+# TFLite modules to support TensorBoard plugin.
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+py_library(
+    name = "ops_util",
+    srcs = ["ops_util.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/lite/toco/python:tensorflow_wrap_toco",
+        "//tensorflow/python:util",
+    ],
+)
+
+py_test(
+    name = "ops_util_test",
+    srcs = ["ops_util_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ops_util",
+        "//tensorflow/python:client_testlib",
+    ],
+)
tensorflow/lite/experimental/tensorboard/README.md (new file, 4 lines)
@@ -0,0 +1,4 @@
+This folder contains basic modules to support TFLite plugin for TensorBoard.
+
+Warning: Everything in this directory is experimental and highly subject to
+changes.
tensorflow/lite/experimental/tensorboard/ops_util.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops util to handle ops for Lite."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+from tensorflow.lite.toco.python import tensorflow_wrap_toco
+from tensorflow.python.util.tf_export import tf_export
+
+
+class SupportedOp(collections.namedtuple("SupportedOp", ["op"])):
+  """Spec of supported ops.
+
+  Args:
+    op: string of op name.
+  """
+
+
+@tf_export("lite.experimental.get_potentially_supported_ops")
+def get_potentially_supported_ops():
+  """Returns operations potentially supported by TensorFlow Lite.
+
+  The potentially support list contains a list of ops that are partially or
+  fully supported, which is derived by simply scanning op names to check whether
+  they can be handled without real conversion and specific parameters.
+
+  Given that some ops may be partially supported, the optimal way to determine
+  if a model's operations are supported is by converting using the TensorFlow
+  Lite converter.
+
+  Returns:
+    A list of SupportedOp.
+  """
+  ops = tensorflow_wrap_toco.TocoGetPotentiallySupportedOps()
+  return [SupportedOp(o["op"]) for o in ops]
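Usage sketch (illustrative, not part of the commit): a minimal call to the new helper defined above, assuming a TensorFlow build that includes this change and the tensorflow_wrap_toco extension declared in the BUILD file. The op names checked here are the ones asserted in the unit test below.

# Hypothetical usage of ops_util.get_potentially_supported_ops().
from tensorflow.lite.experimental.tensorboard import ops_util

supported = ops_util.get_potentially_supported_ops()
# Each entry is a SupportedOp namedtuple, e.g. SupportedOp(op='Softmax').
supported_names = {s.op for s in supported}
print(len(supported_names), "potentially supported op names")
print("Softmax" in supported_names)  # True, per ops_util_test.py below.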
tensorflow/lite/experimental/tensorboard/ops_util_test.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for backend."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.lite.experimental.tensorboard import ops_util
+from tensorflow.python.platform import test
+
+
+class OpsUtilTest(test.TestCase):
+
+  def testGetPotentiallySupportedOps(self):
+    ops = ops_util.get_potentially_supported_ops()
+    # See GetTensorFlowNodeConverterMap() in
+    # tensorflow/lite/toco/import_tensorflow.cc
+    self.assertIsInstance(ops, list)
+    # Test partial ops that surely exist in the list.
+    self.assertIn(ops_util.SupportedOp("Add"), ops)
+    self.assertIn(ops_util.SupportedOp("Log"), ops)
+    self.assertIn(ops_util.SupportedOp("Sigmoid"), ops)
+    self.assertIn(ops_util.SupportedOp("Softmax"), ops)
+
+
+if __name__ == "__main__":
+  test.main()
@@ -73,6 +73,7 @@ py_library(
         ":op_hint",
         ":util",
         "//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
+        "//tensorflow/lite/experimental/tensorboard:ops_util",
        "//tensorflow/lite/python/optimize:calibrator",
        "//tensorflow/python:graph_util",
        "//tensorflow/python/keras",
@@ -24,9 +24,11 @@ from six import PY3
 
 from google.protobuf import text_format as _text_format
 from google.protobuf.message import DecodeError
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
 from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn  # pylint: disable=unused-import
 from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell  # pylint: disable=unused-import
 from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell  # pylint: disable=unused-import
+from tensorflow.lite.experimental.tensorboard.ops_util import get_potentially_supported_ops  # pylint: disable=unused-import
 from tensorflow.lite.python import lite_constants as constants
 from tensorflow.lite.python.convert import build_toco_convert_protos  # pylint: disable=unused-import
 from tensorflow.lite.python.convert import ConverterError  # pylint: disable=unused-import
@@ -47,7 +49,6 @@ from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_te
 from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
 from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
 from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
-from tensorflow.core.framework import graph_pb2 as _graph_pb2
 from tensorflow.python import keras as _keras
 from tensorflow.python.client import session as _session
 from tensorflow.python.eager import def_function as _def_function
@@ -1394,5 +1394,11 @@ class FromKerasFile(test_util.TensorFlowTestCase):
     interpreter.allocate_tensors()
 
 
+class ImportOpsUtilTest(test_util.TensorFlowTestCase):
+
+  def testGetPotentiallySupportedOps(self):
+    self.assertIsNotNone(lite.get_potentially_supported_ops())
+
+
 if __name__ == '__main__':
   test.main()
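Illustrative note (not part of the commit): the re-export in lite.py together with the @tf_export("lite.experimental.get_potentially_supported_ops") decorator makes the helper reachable as tf.lite.experimental.get_potentially_supported_ops in builds that include this change. A hedged sketch of the intended TensorBoard-style check, where graph_def stands in for any GraphDef you already have (hypothetical variable, not from the commit):

# Sketch: which op types in a graph are not even potentially convertible?
import tensorflow as tf

def definitely_unsupported_ops(graph_def):
    # Op types absent from the potentially-supported set will not convert;
    # op types present in it may still fail for specific parameters, per the
    # docstring in ops_util.py.
    supported = {s.op for s in tf.lite.experimental.get_potentially_supported_ops()}
    return {node.op for node in graph_def.node} - supported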
@@ -2629,4 +2629,16 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
   }
   return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
 }
+
+std::vector<std::string> GetPotentiallySupportedOps() {
+  std::vector<std::string> supported_ops;
+  const internal::ConverterMapType& converter_map =
+      internal::GetTensorFlowNodeConverterMap();
+
+  for (const auto& item : converter_map) {
+    supported_ops.push_back(item.first);
+  }
+  return supported_ops;
+}
+
 }  // namespace toco
@@ -17,9 +17,9 @@ limitations under the License.
 
 #include <memory>
 #include <string>
+#include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/lite/toco/model.h"
 #include "tensorflow/lite/toco/model_flags.pb.h"
-#include "tensorflow/core/framework/graph.pb.h"
 
 namespace toco {
 
@@ -34,14 +34,20 @@ struct TensorFlowImportFlags {
   bool import_all_ops_as_unsupported = false;
 };
 
+// Converts TOCO model from TensorFlow GraphDef with given flags.
 std::unique_ptr<Model> ImportTensorFlowGraphDef(
     const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
     const tensorflow::GraphDef& graph_def);
 
+// Converts TOCO model from the file content of TensorFlow GraphDef with given
+// flags.
 std::unique_ptr<Model> ImportTensorFlowGraphDef(
     const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
     const string& input_file_contents);
 
+// Gets a list of supported ops by their names.
+std::vector<std::string> GetPotentiallySupportedOps();
+
 }  // namespace toco
 
 #endif  // TENSORFLOW_LITE_TOCO_IMPORT_TENSORFLOW_H_
@@ -32,4 +32,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
                       PyObject* input_contents_txt_raw,
                       bool extended_return = false);
 
+// Returns a list of names of all ops potentially supported by tflite.
+PyObject* TocoGetPotentiallySupportedOps();
+
 }  // namespace toco
@@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <map>
 #include <string>
 #include <vector>
-#include "tensorflow/core/platform/logging.h"
 
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
+#include "tensorflow/lite/toco/import_tensorflow.h"
 #include "tensorflow/lite/toco/model_flags.pb.h"
 #include "tensorflow/lite/toco/python/toco_python_api.h"
 #include "tensorflow/lite/toco/toco_flags.pb.h"
@@ -49,21 +51,32 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
   bool error;
   std::string model_flags_proto_txt =
       ConvertArg(model_flags_proto_txt_raw, &error);
-  if (error) return nullptr;
+  if (error) {
+    PyErr_SetString(PyExc_ValueError, "Model flags are invalid.");
+    return nullptr;
+  }
   std::string toco_flags_proto_txt =
       ConvertArg(toco_flags_proto_txt_raw, &error);
-  if (error) return nullptr;
+  if (error) {
+    PyErr_SetString(PyExc_ValueError, "Toco flags are invalid.");
+    return nullptr;
+  }
   std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
-  if (error) return nullptr;
+  if (error) {
+    PyErr_SetString(PyExc_ValueError, "Input GraphDef is invalid.");
+    return nullptr;
+  }
 
   // Use TOCO to produce new outputs.
   toco::ModelFlags model_flags;
   if (!model_flags.ParseFromString(model_flags_proto_txt)) {
-    LOG(FATAL) << "Model proto failed to parse." << std::endl;
+    PyErr_SetString(PyExc_ValueError, "Model proto failed to parse.");
+    return nullptr;
   }
   toco::TocoFlags toco_flags;
   if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
-    LOG(FATAL) << "Toco proto failed to parse." << std::endl;
+    PyErr_SetString(PyExc_ValueError, "Toco proto failed to parse.");
+    return nullptr;
   }
 
   auto& dump_options = *GraphVizDumpOptions::singleton();
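Illustrative note (not part of the commit): because the error paths above now set a Python exception (PyErr_SetString with PyExc_ValueError) and return nullptr instead of LOG(FATAL), malformed inputs surface to Python callers as a ValueError rather than aborting the process. A hedged sketch, assuming the SWIG module exposes the function under the name TocoConvert declared in toco_python_api.h (the call below is only an illustration of the new failure mode):

from tensorflow.lite.toco.python import tensorflow_wrap_toco

try:
    # Deliberately malformed flag protos to exercise the new error path.
    tensorflow_wrap_toco.TocoConvert(b"not a proto", b"also not a proto", b"", False)
except ValueError as err:
    print("TOCO rejected the request:", err)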
@@ -100,4 +113,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
       output_file_contents_txt.data(), output_file_contents_txt.size());
 }
 
+PyObject* TocoGetPotentiallySupportedOps() {
+  std::vector<std::string> supported_ops = toco::GetPotentiallySupportedOps();
+  PyObject* list = PyList_New(supported_ops.size());
+  for (size_t i = 0; i < supported_ops.size(); ++i) {
+    const string& op = supported_ops[i];
+    PyObject* op_dict = PyDict_New();
+    PyDict_SetItemString(op_dict, "op", PyUnicode_FromString(op.c_str()));
+    PyList_SetItem(list, i, op_dict);
+  }
+  return list;
+}
+
 }  // namespace toco
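Illustrative note (not part of the commit): the function above hands a list of single-key dicts across the C++/Python boundary, which ops_util.py then wraps into SupportedOp tuples. A pure-Python stand-in for that data shape, only to make the contract explicit:

# Stand-in for the extension module's return value; the real list comes from
# tensorflow_wrap_toco.TocoGetPotentiallySupportedOps().
import collections

SupportedOp = collections.namedtuple("SupportedOp", ["op"])
raw_ops = [{"op": "Add"}, {"op": "Softmax"}]  # shape produced by the C++ code above
wrapped = [SupportedOp(o["op"]) for o in raw_ops]  # mirrors ops_util.get_potentially_supported_ops
assert wrapped[0] == SupportedOp("Add")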
@@ -31,6 +31,9 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
                       PyObject* input_contents_txt_raw,
                       bool extended_return = false);
 
+// Returns a list of names of all ops potentially supported by tflite.
+PyObject* TocoGetPotentiallySupportedOps();
+
 }  // namespace toco
 
 #endif  // TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
@@ -125,12 +125,12 @@ def freeze_graph_with_def_protos(input_graph_def,
   # 'input_checkpoint' may be a prefix if we're using Saver V2 format
   if (not input_saved_model_dir and
       not checkpoint_management.checkpoint_exists(input_checkpoint)):
-    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
-    return -1
+    raise ValueError("Input checkpoint '" + input_checkpoint +
+                     "' doesn't exist!")
 
   if not output_node_names:
-    print("You need to supply the name of a node to --output_node_names.")
-    return -1
+    raise ValueError(
+        "You need to supply the name of a node to --output_node_names.")
 
   # Remove all the explicit device specifications for this node. This helps to
   # make the graph more portable.
|
|||||||
# tensors. Partition variables are Identity tensors that cannot be
|
# tensors. Partition variables are Identity tensors that cannot be
|
||||||
# handled by Saver.
|
# handled by Saver.
|
||||||
if has_partition_var:
|
if has_partition_var:
|
||||||
print("Models containing partition variables cannot be converted "
|
raise ValueError(
|
||||||
|
"Models containing partition variables cannot be converted "
|
||||||
"from checkpoint files. Please pass in a SavedModel using "
|
"from checkpoint files. Please pass in a SavedModel using "
|
||||||
"the flag --input_saved_model_dir.")
|
"the flag --input_saved_model_dir.")
|
||||||
return -1
|
|
||||||
# Models that have been frozen previously do not contain Variables.
|
# Models that have been frozen previously do not contain Variables.
|
||||||
elif _has_no_variables(sess):
|
elif _has_no_variables(sess):
|
||||||
print("No variables were found in this model. It is likely the model "
|
raise ValueError(
|
||||||
|
"No variables were found in this model. It is likely the model "
|
||||||
"was frozen previously. You cannot freeze a graph twice.")
|
"was frozen previously. You cannot freeze a graph twice.")
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
@ -242,8 +243,7 @@ def freeze_graph_with_def_protos(input_graph_def,
|
|||||||
def _parse_input_graph_proto(input_graph, input_binary):
|
def _parse_input_graph_proto(input_graph, input_binary):
|
||||||
"""Parses input tensorflow graph into GraphDef proto."""
|
"""Parses input tensorflow graph into GraphDef proto."""
|
||||||
if not gfile.Exists(input_graph):
|
if not gfile.Exists(input_graph):
|
||||||
print("Input graph file '" + input_graph + "' does not exist!")
|
raise IOError("Input graph file '" + input_graph + "' does not exist!")
|
||||||
return -1
|
|
||||||
input_graph_def = graph_pb2.GraphDef()
|
input_graph_def = graph_pb2.GraphDef()
|
||||||
mode = "rb" if input_binary else "r"
|
mode = "rb" if input_binary else "r"
|
||||||
with gfile.GFile(input_graph, mode) as f:
|
with gfile.GFile(input_graph, mode) as f:
|
||||||
@ -257,8 +257,7 @@ def _parse_input_graph_proto(input_graph, input_binary):
|
|||||||
def _parse_input_meta_graph_proto(input_graph, input_binary):
|
def _parse_input_meta_graph_proto(input_graph, input_binary):
|
||||||
"""Parses input tensorflow graph into MetaGraphDef proto."""
|
"""Parses input tensorflow graph into MetaGraphDef proto."""
|
||||||
if not gfile.Exists(input_graph):
|
if not gfile.Exists(input_graph):
|
||||||
print("Input meta graph file '" + input_graph + "' does not exist!")
|
raise IOError("Input meta graph file '" + input_graph + "' does not exist!")
|
||||||
return -1
|
|
||||||
input_meta_graph_def = MetaGraphDef()
|
input_meta_graph_def = MetaGraphDef()
|
||||||
mode = "rb" if input_binary else "r"
|
mode = "rb" if input_binary else "r"
|
||||||
with gfile.GFile(input_graph, mode) as f:
|
with gfile.GFile(input_graph, mode) as f:
|
||||||
@ -273,8 +272,7 @@ def _parse_input_meta_graph_proto(input_graph, input_binary):
|
|||||||
def _parse_input_saver_proto(input_saver, input_binary):
|
def _parse_input_saver_proto(input_saver, input_binary):
|
||||||
"""Parses input tensorflow Saver into SaverDef proto."""
|
"""Parses input tensorflow Saver into SaverDef proto."""
|
||||||
if not gfile.Exists(input_saver):
|
if not gfile.Exists(input_saver):
|
||||||
print("Input saver file '" + input_saver + "' does not exist!")
|
raise IOError("Input saver file '" + input_saver + "' does not exist!")
|
||||||
return -1
|
|
||||||
mode = "rb" if input_binary else "r"
|
mode = "rb" if input_binary else "r"
|
||||||
with gfile.GFile(input_saver, mode) as f:
|
with gfile.GFile(input_saver, mode) as f:
|
||||||
saver_def = saver_pb2.SaverDef()
|
saver_def = saver_pb2.SaverDef()
|
||||||
@ -369,9 +367,8 @@ def main(unused_args, flags):
|
|||||||
elif flags.checkpoint_version == 2:
|
elif flags.checkpoint_version == 2:
|
||||||
checkpoint_version = saver_pb2.SaverDef.V2
|
checkpoint_version = saver_pb2.SaverDef.V2
|
||||||
else:
|
else:
|
||||||
print("Invalid checkpoint version (must be '1' or '2'): %d" %
|
raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
|
||||||
flags.checkpoint_version)
|
flags.checkpoint_version)
|
||||||
return -1
|
|
||||||
freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
|
freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
|
||||||
flags.input_checkpoint, flags.output_node_names,
|
flags.input_checkpoint, flags.output_node_names,
|
||||||
flags.restore_op_name, flags.filename_tensor_name,
|
flags.restore_op_name, flags.filename_tensor_name,
|
||||||
@ -380,7 +377,9 @@ def main(unused_args, flags):
|
|||||||
flags.input_meta_graph, flags.input_saved_model_dir,
|
flags.input_meta_graph, flags.input_saved_model_dir,
|
||||||
flags.saved_model_tags, checkpoint_version)
|
flags.saved_model_tags, checkpoint_version)
|
||||||
|
|
||||||
|
|
||||||
def run_main():
|
def run_main():
|
||||||
|
"""Main function of freeze_graph."""
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -487,5 +486,6 @@ def run_main():
|
|||||||
my_main = lambda unused_args: main(unused_args, flags)
|
my_main = lambda unused_args: main(unused_args, flags)
|
||||||
app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
|
app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
run_main()
|
run_main()
|
||||||
|
@ -316,7 +316,8 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
|||||||
output_node_names = "save/restore_all"
|
output_node_names = "save/restore_all"
|
||||||
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
|
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
|
||||||
|
|
||||||
return_value = freeze_graph.freeze_graph_with_def_protos(
|
with self.assertRaises(ValueError):
|
||||||
|
freeze_graph.freeze_graph_with_def_protos(
|
||||||
input_graph_def=sess.graph_def,
|
input_graph_def=sess.graph_def,
|
||||||
input_saver_def=None,
|
input_saver_def=None,
|
||||||
input_checkpoint=checkpoint_path,
|
input_checkpoint=checkpoint_path,
|
||||||
@ -326,7 +327,6 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
|||||||
output_graph=output_graph_path,
|
output_graph=output_graph_path,
|
||||||
clear_devices=False,
|
clear_devices=False,
|
||||||
initializer_nodes="")
|
initializer_nodes="")
|
||||||
self.assertTrue(return_value, -1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -8,4 +8,8 @@ tf_module {
|
|||||||
name: "convert_op_hints_to_stubs"
|
name: "convert_op_hints_to_stubs"
|
||||||
argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'<function <lambda> instance>\'], "
|
argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'<function <lambda> instance>\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_potentially_supported_ops"
|
||||||
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,4 +8,8 @@ tf_module {
|
|||||||
name: "convert_op_hints_to_stubs"
|
name: "convert_op_hints_to_stubs"
|
||||||
argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'<function <lambda> instance>\'], "
|
argspec: "args=[\'session\', \'graph_def\', \'write_callback\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'<function <lambda> instance>\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_potentially_supported_ops"
|
||||||
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|