Adds tests to lite_test.py to test functions in 1.X.
1. Adds support for a Variable --> Identity --> ReadVariableOp graph to convert_variables_to_constants. 2. Adds a Grappler pass before freezing the graph in lite.py in order to inline. PiperOrigin-RevId: 234030833
This commit is contained in:
parent
e7d9786e66
commit
6249068668
@ -38,6 +38,10 @@ from six import PY3
|
|||||||
|
|
||||||
from google.protobuf import text_format as _text_format
|
from google.protobuf import text_format as _text_format
|
||||||
from google.protobuf.message import DecodeError
|
from google.protobuf.message import DecodeError
|
||||||
|
from tensorflow.core.framework import graph_pb2 as _graph_pb2
|
||||||
|
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
||||||
|
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
||||||
|
from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
|
||||||
from tensorflow.lite.python import lite_constants as constants
|
from tensorflow.lite.python import lite_constants as constants
|
||||||
from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
|
from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
|
||||||
from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import
|
from tensorflow.lite.python.convert import ConverterError # pylint: disable=unused-import
|
||||||
@ -54,15 +58,12 @@ from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=un
|
|||||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
|
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
|
||||||
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
|
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
|
||||||
from tensorflow.lite.python.optimize import calibrator as _calibrator
|
from tensorflow.lite.python.optimize import calibrator as _calibrator
|
||||||
from tensorflow.core.framework import graph_pb2 as _graph_pb2
|
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2 as _rewriter_config_pb2
|
|
||||||
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
|
||||||
from tensorflow.python import keras as _keras
|
from tensorflow.python import keras as _keras
|
||||||
from tensorflow.python.client import session as _session
|
from tensorflow.python.client import session as _session
|
||||||
from tensorflow.python.framework import graph_util as _tf_graph_util
|
from tensorflow.python.framework import graph_util as _tf_graph_util
|
||||||
from tensorflow.python.framework import ops as _ops
|
from tensorflow.python.framework import ops as _ops
|
||||||
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
|
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
|
||||||
|
from tensorflow.python.framework.errors_impl import OpError as _OpError
|
||||||
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
|
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
|
||||||
from tensorflow.python.grappler import tf_optimizer as _tf_optimizer
|
from tensorflow.python.grappler import tf_optimizer as _tf_optimizer
|
||||||
from tensorflow.python.lib.io import file_io as _file_io
|
from tensorflow.python.lib.io import file_io as _file_io
|
||||||
@ -73,35 +74,6 @@ from tensorflow.python.util import deprecation as _deprecation
|
|||||||
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||||
|
|
||||||
|
|
||||||
def _run_graph_optimizations(graph_def, input_arrays, output_arrays):
|
|
||||||
"""Apply standard TensorFlow optimizations to the graph_def.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_def: Frozen GraphDef to be optimized.
|
|
||||||
input_arrays: List of arrays that are considered inputs of the graph.
|
|
||||||
output_arrays: List of arrays that are considered outputs of the graph.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A new, optimized GraphDef.
|
|
||||||
"""
|
|
||||||
meta_graph = _export_meta_graph(graph_def=graph_def)
|
|
||||||
|
|
||||||
# We need to add a collection called 'train_op' so that grappler
|
|
||||||
# knows what the outputs are.
|
|
||||||
fetch_collection = _meta_graph_pb2.CollectionDef()
|
|
||||||
for array in input_arrays + output_arrays:
|
|
||||||
fetch_collection.node_list.value.append(array.name)
|
|
||||||
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
|
|
||||||
|
|
||||||
config = _config_pb2.ConfigProto()
|
|
||||||
rewrite_options = config.graph_options.rewrite_options
|
|
||||||
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.ON
|
|
||||||
# Avoid remapping as it creates ops like _FusedConv2D, which are not
|
|
||||||
# supported by TF Lite.
|
|
||||||
rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
|
|
||||||
return _tf_optimizer.OptimizeGraph(config, meta_graph)
|
|
||||||
|
|
||||||
|
|
||||||
@_tf_export("lite.Optimize")
|
@_tf_export("lite.Optimize")
|
||||||
class Optimize(enum.Enum):
|
class Optimize(enum.Enum):
|
||||||
"""Enum defining the optimizations to apply when generating tflite graphs.
|
"""Enum defining the optimizations to apply when generating tflite graphs.
|
||||||
@ -311,7 +283,7 @@ class TFLiteConverter(object):
|
|||||||
Returns:
|
Returns:
|
||||||
TFLiteConverter class.
|
TFLiteConverter class.
|
||||||
"""
|
"""
|
||||||
graph_def = _freeze_graph(sess, output_tensors)
|
graph_def = _freeze_graph(sess, input_tensors, output_tensors)
|
||||||
return cls(graph_def, input_tensors, output_tensors)
|
return cls(graph_def, input_tensors, output_tensors)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -484,7 +456,7 @@ class TFLiteConverter(object):
|
|||||||
output_tensors = keras_model.outputs
|
output_tensors = keras_model.outputs
|
||||||
_set_tensor_shapes(input_tensors, input_shapes)
|
_set_tensor_shapes(input_tensors, input_shapes)
|
||||||
|
|
||||||
graph_def = _freeze_graph(sess, output_tensors)
|
graph_def = _freeze_graph(sess, input_tensors, output_tensors)
|
||||||
return cls(graph_def, input_tensors, output_tensors)
|
return cls(graph_def, input_tensors, output_tensors)
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
@ -586,26 +558,26 @@ class TFLiteConverter(object):
|
|||||||
"dump_graphviz_video": self.dump_graphviz_video
|
"dump_graphviz_video": self.dump_graphviz_video
|
||||||
}
|
}
|
||||||
|
|
||||||
optimized_graph = None
|
# Run a Grappler pass if it is possible.
|
||||||
if self.inference_type == constants.QUANTIZED_UINT8:
|
graph_def = self._graph_def
|
||||||
optimized_graph = self._graph_def
|
if self.inference_type != constants.QUANTIZED_UINT8:
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
optimized_graph = _run_graph_optimizations(
|
graph_def = _run_graph_optimizations(
|
||||||
self._graph_def, self._input_tensors, self._output_tensors)
|
self._graph_def, self._input_tensors, self._output_tensors)
|
||||||
except Exception:
|
except (_OpError, ValueError):
|
||||||
optimized_graph = self._graph_def
|
print("Warning: Grappler optimization pass failed. "
|
||||||
|
"If this behavior is unexpected, please file a bug.")
|
||||||
|
|
||||||
# Converts model.
|
# Converts model.
|
||||||
if self._has_valid_tensors():
|
if self._has_valid_tensors():
|
||||||
result = _toco_convert_impl(
|
result = _toco_convert_impl(
|
||||||
input_data=optimized_graph,
|
input_data=graph_def,
|
||||||
input_tensors=self._input_tensors,
|
input_tensors=self._input_tensors,
|
||||||
output_tensors=self._output_tensors,
|
output_tensors=self._output_tensors,
|
||||||
**converter_kwargs)
|
**converter_kwargs)
|
||||||
else:
|
else:
|
||||||
result = _toco_convert_graph_def(
|
result = _toco_convert_graph_def(
|
||||||
input_data=optimized_graph,
|
input_data=graph_def,
|
||||||
input_arrays_with_shape=self._input_arrays_with_shape,
|
input_arrays_with_shape=self._input_arrays_with_shape,
|
||||||
output_arrays=self._output_arrays,
|
output_arrays=self._output_arrays,
|
||||||
**converter_kwargs)
|
**converter_kwargs)
|
||||||
@ -710,6 +682,35 @@ class TocoConverter(object):
|
|||||||
input_shapes, output_arrays)
|
input_shapes, output_arrays)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_graph_optimizations(graph_def, input_arrays, output_arrays):
|
||||||
|
"""Apply standard TensorFlow optimizations to the graph_def.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph_def: Frozen GraphDef to be optimized.
|
||||||
|
input_arrays: List of arrays that are considered inputs of the graph.
|
||||||
|
output_arrays: List of arrays that are considered outputs of the graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new, optimized GraphDef.
|
||||||
|
"""
|
||||||
|
meta_graph = _export_meta_graph(graph_def=graph_def)
|
||||||
|
|
||||||
|
# We need to add a collection called 'train_op' so that grappler
|
||||||
|
# knows what the outputs are.
|
||||||
|
fetch_collection = _meta_graph_pb2.CollectionDef()
|
||||||
|
for array in input_arrays + output_arrays:
|
||||||
|
fetch_collection.node_list.value.append(array.name)
|
||||||
|
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
|
||||||
|
|
||||||
|
config = _config_pb2.ConfigProto()
|
||||||
|
rewrite_options = config.graph_options.rewrite_options
|
||||||
|
rewrite_options.layout_optimizer = _rewriter_config_pb2.RewriterConfig.OFF
|
||||||
|
# Avoid remapping as it creates ops like _FusedConv2D, which are not
|
||||||
|
# supported by TF Lite.
|
||||||
|
rewrite_options.remapping = _rewriter_config_pb2.RewriterConfig.OFF
|
||||||
|
return _tf_optimizer.OptimizeGraph(config, meta_graph)
|
||||||
|
|
||||||
|
|
||||||
def _is_frozen_graph(sess):
|
def _is_frozen_graph(sess):
|
||||||
"""Determines if the graph is frozen.
|
"""Determines if the graph is frozen.
|
||||||
|
|
||||||
@ -728,22 +729,28 @@ def _is_frozen_graph(sess):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _freeze_graph(sess, output_tensors):
|
def _freeze_graph(sess, input_tensors, output_tensors):
|
||||||
"""Returns a frozen GraphDef.
|
"""Returns a frozen GraphDef.
|
||||||
|
|
||||||
Freezes a graph with Variables in it. Otherwise the existing GraphDef is
|
Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the
|
||||||
returned.
|
existing GraphDef is returned. The Grappler pass is only run on models that
|
||||||
|
are frozen in order to inline the functions in the graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sess: TensorFlow Session.
|
sess: TensorFlow Session.
|
||||||
|
input_tensors: List of input tensors.
|
||||||
output_tensors: List of output tensors (only .name is used from this).
|
output_tensors: List of output tensors (only .name is used from this).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Frozen GraphDef.
|
Frozen GraphDef.
|
||||||
"""
|
"""
|
||||||
|
# Runs a Grappler pass in order to inline any functions in the graph.
|
||||||
|
graph_def = _run_graph_optimizations(sess.graph_def, input_tensors,
|
||||||
|
output_tensors)
|
||||||
|
|
||||||
if not _is_frozen_graph(sess):
|
if not _is_frozen_graph(sess):
|
||||||
output_arrays = [_tensor_name(tensor) for tensor in output_tensors]
|
output_arrays = [_tensor_name(tensor) for tensor in output_tensors]
|
||||||
return _tf_graph_util.convert_variables_to_constants(
|
return _tf_graph_util.convert_variables_to_constants(
|
||||||
sess, sess.graph_def, output_arrays)
|
sess, graph_def, output_arrays)
|
||||||
else:
|
else:
|
||||||
return sess.graph_def
|
return sess.graph_def
|
||||||
|
@ -27,13 +27,16 @@ from tensorflow.lite.python import lite_constants
|
|||||||
from tensorflow.lite.python.interpreter import Interpreter
|
from tensorflow.lite.python.interpreter import Interpreter
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
|
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import resource_loader
|
from tensorflow.python.platform import resource_loader
|
||||||
@ -597,6 +600,48 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
|||||||
interpreter = Interpreter(model_content=tflite_model)
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
def testFunctions(self):
|
||||||
|
"""Tests tf.function in 1.X."""
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def plus_placeholder(x, placeholder):
|
||||||
|
return x + placeholder
|
||||||
|
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
placeholder = array_ops.placeholder(
|
||||||
|
dtype=dtypes.float32, shape=[1], name='input')
|
||||||
|
variable_node = variables.Variable(1.0, name='variable_node')
|
||||||
|
defun_node = plus_placeholder(variable_node, placeholder)
|
||||||
|
output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
|
||||||
|
|
||||||
|
# Initialize variables in the model.
|
||||||
|
sess = session.Session()
|
||||||
|
sess.run(variables.variables_initializer([variable_node]))
|
||||||
|
|
||||||
|
# Convert model and ensure model is not None.
|
||||||
|
converter = lite.TFLiteConverter.from_session(sess, [placeholder],
|
||||||
|
[output_node])
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
self.assertTrue(tflite_model)
|
||||||
|
|
||||||
|
# Check values from converted model.
|
||||||
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
input_details = interpreter.get_input_details()
|
||||||
|
self.assertEqual(1, len(input_details))
|
||||||
|
self.assertEqual('input', input_details[0]['name'])
|
||||||
|
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||||
|
self.assertTrue(([1] == input_details[0]['shape']).all())
|
||||||
|
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||||
|
|
||||||
|
output_details = interpreter.get_output_details()
|
||||||
|
self.assertEqual(1, len(output_details))
|
||||||
|
self.assertEqual('output_node', output_details[0]['name'])
|
||||||
|
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||||
|
self.assertTrue(([1] == output_details[0]['shape']).all())
|
||||||
|
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
@test_util.run_v1_only('b/120545219')
|
||||||
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||||
|
@ -26,8 +26,10 @@ from tensorflow.lite.python import lite
|
|||||||
from tensorflow.lite.testing.model_coverage import model_coverage_lib as model_coverage
|
from tensorflow.lite.testing.model_coverage import model_coverage_lib as model_coverage
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -70,6 +72,26 @@ class EvaluateFrozenGraph(test.TestCase):
|
|||||||
model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
|
model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
|
||||||
['add', 'Mean'])
|
['add', 'Mean'])
|
||||||
|
|
||||||
|
@test_util.run_v1_only('b/120545219')
|
||||||
|
def testFunctions(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def plus_placeholder(x, placeholder):
|
||||||
|
return x + placeholder
|
||||||
|
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
placeholder = array_ops.placeholder(
|
||||||
|
dtype=dtypes.float32, shape=[1], name='input')
|
||||||
|
variable_node = constant_op.constant(1.0, name='variable_node')
|
||||||
|
defun_node = plus_placeholder(variable_node, placeholder)
|
||||||
|
_ = math_ops.multiply(defun_node, 2.0, name='output_node')
|
||||||
|
|
||||||
|
# Initialize variables in the model.
|
||||||
|
sess = session.Session()
|
||||||
|
|
||||||
|
filename = self._saveFrozenGraph(sess)
|
||||||
|
model_coverage.test_frozen_graph(filename, ['input'], ['output_node'])
|
||||||
|
|
||||||
def _getQuantizedModel(self):
|
def _getQuantizedModel(self):
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
with session.Session().as_default() as sess:
|
with session.Session().as_default() as sess:
|
||||||
|
@ -248,6 +248,15 @@ def convert_variables_to_constants(sess,
|
|||||||
found_variables = {}
|
found_variables = {}
|
||||||
variable_names = []
|
variable_names = []
|
||||||
variable_dict_names = []
|
variable_dict_names = []
|
||||||
|
identity_ops_input_map = {}
|
||||||
|
|
||||||
|
def is_found_variable(input_tensor_name):
|
||||||
|
# Determines if the `input_tensor_name` is in `found_variables` or is an
|
||||||
|
# Identity op with an input that is in `found_variables`.
|
||||||
|
return ((input_tensor_name in found_variables) or
|
||||||
|
(input_tensor_name in identity_ops_input_map and
|
||||||
|
identity_ops_input_map[input_tensor_name] in found_variables))
|
||||||
|
|
||||||
for node in inference_graph.node:
|
for node in inference_graph.node:
|
||||||
if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
|
if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
|
||||||
variable_name = node.name
|
variable_name = node.name
|
||||||
@ -261,6 +270,9 @@ def convert_variables_to_constants(sess,
|
|||||||
variable_names.append(variable_name + "/Read/ReadVariableOp:0")
|
variable_names.append(variable_name + "/Read/ReadVariableOp:0")
|
||||||
else:
|
else:
|
||||||
variable_names.append(variable_name + ":0")
|
variable_names.append(variable_name + ":0")
|
||||||
|
elif node.op == "Identity":
|
||||||
|
# Creates a map of Identity node names to the input names.
|
||||||
|
identity_ops_input_map[node.name] = node.input[0].split(":")[0]
|
||||||
if variable_names:
|
if variable_names:
|
||||||
returned_variables = sess.run(variable_names)
|
returned_variables = sess.run(variable_names)
|
||||||
else:
|
else:
|
||||||
@ -283,11 +295,15 @@ def convert_variables_to_constants(sess,
|
|||||||
tensor=tensor_util.make_tensor_proto(
|
tensor=tensor_util.make_tensor_proto(
|
||||||
data, dtype=dtype.type, shape=data.shape)))
|
data, dtype=dtype.type, shape=data.shape)))
|
||||||
how_many_converted += 1
|
how_many_converted += 1
|
||||||
elif input_node.op == "ReadVariableOp" and (
|
elif (input_node.op == "ReadVariableOp" and
|
||||||
input_node.input[0] in found_variables):
|
is_found_variable(input_node.input[0])):
|
||||||
# The preceding branch converts all VarHandleOps of ResourceVariables to
|
# The preceding branch converts all VarHandleOps of ResourceVariables to
|
||||||
# constants, so we need to convert the associated ReadVariableOps to
|
# constants, so we need to convert the associated ReadVariableOps to
|
||||||
# Identity ops.
|
# Identity ops.
|
||||||
|
#
|
||||||
|
# Handles the following cases:
|
||||||
|
# Variable --> ReadVariableOp
|
||||||
|
# Variable --> Identity --> ReadVariableOp
|
||||||
output_node.op = "Identity"
|
output_node.op = "Identity"
|
||||||
output_node.name = input_node.name
|
output_node.name = input_node.name
|
||||||
output_node.input.extend([input_node.input[0]])
|
output_node.input.extend([input_node.input[0]])
|
||||||
|
@ -217,16 +217,18 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
self.assertNear(4.0, output, 0.00001)
|
self.assertNear(4.0, output, 0.00001)
|
||||||
variable_graph_def = sess.graph.as_graph_def()
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
|
||||||
# First get the constant_graph_def when variable_names_whitelist is set,
|
# Get the constant_graph_def.
|
||||||
# note that if variable_names_whitelist is not set an error will be
|
|
||||||
# thrown because unused_variable_node is not initialized.
|
|
||||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
constant_graph_def = graph_util.convert_variables_to_constants(
|
||||||
sess,
|
sess, variable_graph_def, ["output_node"])
|
||||||
variable_graph_def, ["output_node"],
|
|
||||||
variable_names_whitelist=set(["variable_node"]))
|
|
||||||
|
|
||||||
|
# Ensure the library is copied and there are no variables after
|
||||||
|
# freezing.
|
||||||
self.assertEqual(variable_graph_def.library,
|
self.assertEqual(variable_graph_def.library,
|
||||||
constant_graph_def.library)
|
constant_graph_def.library)
|
||||||
|
for node in constant_graph_def.node:
|
||||||
|
self.assertNotIn(
|
||||||
|
node.op,
|
||||||
|
["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||||
|
|
||||||
def testConvertVariablesToConsts(self):
|
def testConvertVariablesToConsts(self):
|
||||||
self._test_variable_to_const_conversion(use_resource=False)
|
self._test_variable_to_const_conversion(use_resource=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user