Unify convert to constants logic.
Before this change there were two different code paths for dealing with graph freezing for v1 and v2 graphs. They largely did the same thing, but each path had and lacked capabilities the other had, and each had its own bugs. This change re-writes the previous v2 logic so it can cope with session-based graphs and allow for the conversion of a subset of the variables, and changes the previous convert_variables_to_constants call to proxy into the new logic. The new logic is built around a more "graphy" algorithm: variables are converted to constants, and that conversion is then propagated through the graph by following the graph edges. This hopefully makes it easier to understand what is going on, and to change it later on. More granular tests were added, in order to check that the right graph manipulations were performed. In order to do that, some graph merging infrastructure had to be created in the test. PiperOrigin-RevId: 315497124 Change-Id: I3a33acc804b5dc9628c208df8fd1b7c59f906ddb
This commit is contained in:
parent
75f09e6973
commit
ed52277026
@ -6633,6 +6633,7 @@ tf_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":client",
|
":client",
|
||||||
":client_testlib",
|
":client_testlib",
|
||||||
|
":control_flow_v2_toggles",
|
||||||
":framework",
|
":framework",
|
||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
":math_ops",
|
":math_ops",
|
||||||
@ -6650,9 +6651,11 @@ tf_py_test(
|
|||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = ["no_rocm"],
|
tags = ["no_rocm"],
|
||||||
deps = [
|
deps = [
|
||||||
"client_testlib",
|
":client_testlib",
|
||||||
"framework_test_lib",
|
":control_flow_v2_toggles",
|
||||||
":convert_to_constants",
|
":convert_to_constants",
|
||||||
|
":framework_test_lib",
|
||||||
|
":math_ops",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -19,33 +19,129 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from google.protobuf import text_format
|
||||||
|
from tensorflow.core.framework import function_pb2
|
||||||
|
from tensorflow.core.framework import graph_pb2
|
||||||
|
from tensorflow.core.framework import node_def_pb2
|
||||||
|
from tensorflow.core.framework import op_def_pb2
|
||||||
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
from tensorflow.python.client import session as session_lib
|
from tensorflow.python.client import session as session_lib
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import convert_to_constants
|
from tensorflow.python.framework import convert_to_constants
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import function
|
||||||
|
from tensorflow.python.framework import importer
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.grappler import tf_optimizer
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import cond_v2
|
from tensorflow.python.ops import cond_v2
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import control_flow_v2_toggles
|
||||||
|
from tensorflow.python.ops import gen_math_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import rnn
|
from tensorflow.python.ops import rnn
|
||||||
from tensorflow.python.ops import rnn_cell_impl
|
from tensorflow.python.ops import rnn_cell_impl
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.ops import while_v2
|
from tensorflow.python.ops import while_v2
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.saved_model import simple_save
|
from tensorflow.python.saved_model import simple_save
|
||||||
from tensorflow.python.saved_model.load import load
|
from tensorflow.python.saved_model.load import load
|
||||||
from tensorflow.python.saved_model.save import save
|
from tensorflow.python.saved_model.save import save
|
||||||
|
from tensorflow.python.training.saver import export_meta_graph
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
|
class _GraphMerger(object):
|
||||||
|
"""GraphDef merging methods for testing purposes."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_any(x1, x2, empty_fn):
|
||||||
|
"""Merges two values using the message's CopyFrom/MergeFrom methods."""
|
||||||
|
merged = empty_fn()
|
||||||
|
merged.CopyFrom(x1)
|
||||||
|
merged.MergeFrom(x2)
|
||||||
|
return merged
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_nodes(node1, node2):
|
||||||
|
"""Merges two NodeDef messages."""
|
||||||
|
merged = _GraphMerger.merge_any(node1, node2, node_def_pb2.NodeDef)
|
||||||
|
merged_inputs = node1.input[:]
|
||||||
|
merged_inputs.extend([i for i in node2.input[:] if i not in merged_inputs])
|
||||||
|
merged.input[:] = merged_inputs
|
||||||
|
return merged
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_lists(repeated1, repeated2, empty_fn, key_fn, merge_fn):
|
||||||
|
"""Merges two lists representing maps."""
|
||||||
|
merged = {}
|
||||||
|
xs1 = {key_fn(x): x for x in repeated1}
|
||||||
|
xs2 = {key_fn(x): x for x in repeated2}
|
||||||
|
for name in set().union(xs1.keys(), xs2.keys()):
|
||||||
|
x1 = empty_fn() if name not in xs1 else xs1[name]
|
||||||
|
x2 = empty_fn() if name not in xs2 else xs2[name]
|
||||||
|
merged[name] = merge_fn(x1, x2)
|
||||||
|
return sorted(merged.values(), key=key_fn)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_node_lists(repeated_nodes1, repeated_nodes2):
|
||||||
|
"""Merges two repeated node fields."""
|
||||||
|
return _GraphMerger.merge_lists(repeated_nodes1, repeated_nodes2,
|
||||||
|
node_def_pb2.NodeDef, lambda n: n.name,
|
||||||
|
_GraphMerger.merge_nodes)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_functions(fn1, fn2):
|
||||||
|
"""Merges two FunctionDefs."""
|
||||||
|
merged = _GraphMerger.merge_any(fn1, fn2, function_pb2.FunctionDef)
|
||||||
|
|
||||||
|
del merged.signature.input_arg[:]
|
||||||
|
merged.signature.input_arg.extend(
|
||||||
|
_GraphMerger.merge_lists(
|
||||||
|
fn1.signature.input_arg[:], fn2.signature.input_arg[:],
|
||||||
|
op_def_pb2.OpDef.ArgDef, lambda a: a.name,
|
||||||
|
lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))
|
||||||
|
|
||||||
|
del merged.signature.output_arg[:]
|
||||||
|
merged.signature.output_arg.extend(
|
||||||
|
_GraphMerger.merge_lists(
|
||||||
|
fn1.signature.output_arg[:], fn2.signature.output_arg[:],
|
||||||
|
op_def_pb2.OpDef.ArgDef, lambda a: a.name,
|
||||||
|
lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))
|
||||||
|
|
||||||
|
del merged.node_def[:]
|
||||||
|
merged.node_def.extend(
|
||||||
|
_GraphMerger.merge_node_lists(fn1.node_def[:], fn2.node_def[:]))
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_graphs(graph1, graph2):
|
||||||
|
"""Merges two GraphDef messages."""
|
||||||
|
merged = graph_pb2.GraphDef()
|
||||||
|
merged.node.extend(
|
||||||
|
_GraphMerger.merge_node_lists(graph1.node[:], graph2.node[:]))
|
||||||
|
|
||||||
|
merged.library.function.extend(
|
||||||
|
_GraphMerger.merge_lists(graph1.library.function,
|
||||||
|
graph2.library.function,
|
||||||
|
function_pb2.FunctionDef,
|
||||||
|
lambda f: f.signature.name,
|
||||||
|
_GraphMerger.merge_functions))
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
class VariablesToConstantsTest(test.TestCase):
|
class VariablesToConstantsTest(test.TestCase):
|
||||||
|
|
||||||
def _freezeModel(self, model):
|
def _freezeModel(self, model):
|
||||||
@ -325,6 +421,7 @@ class VariablesToConstantsTest(test.TestCase):
|
|||||||
cell, seq, dtype=dtypes.float32, sequence_length=[1])
|
cell, seq, dtype=dtypes.float32, sequence_length=[1])
|
||||||
|
|
||||||
root, output_func = self._freezeModel(model)
|
root, output_func = self._freezeModel(model)
|
||||||
|
|
||||||
self._testConvertedFunction(root, root.f, output_func, input_data)
|
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
@ -347,6 +444,7 @@ class VariablesToConstantsTest(test.TestCase):
|
|||||||
return control_flow_ops.while_loop(condition, body, [x])
|
return control_flow_ops.while_loop(condition, body, [x])
|
||||||
|
|
||||||
root, output_func = self._freezeModel(model)
|
root, output_func = self._freezeModel(model)
|
||||||
|
|
||||||
self._testConvertedFunction(root, root.f, output_func, input_data)
|
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
@ -389,5 +487,665 @@ class VariablesToConstantsTest(test.TestCase):
|
|||||||
self._testConvertedFunction(root, root.f, output_func, input_data)
|
self._testConvertedFunction(root, root.f, output_func, input_data)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertVariablesToConstantsSessionTest(test.TestCase):
|
||||||
|
|
||||||
|
def _assertGraphContains(self, graph, subgraph):
|
||||||
|
"""Asserts that the given subgraph is contained within the given graph."""
|
||||||
|
|
||||||
|
def normalize_uids(msg):
|
||||||
|
"""Replace auto-id function names with something consistent."""
|
||||||
|
# These functions have non-deterministic names, the non-determinism coming
|
||||||
|
# from having an ops.uid() suffix in their names. We're replacing these
|
||||||
|
# with new sequential IDs starting from 0 for each prefix, which is
|
||||||
|
# is sufficient for tests.
|
||||||
|
if isinstance(msg, graph_pb2.GraphDef):
|
||||||
|
msg = text_format.MessageToString(msg)
|
||||||
|
name_prefixes = ["case_cond_true.*", "case_cond_false.*"]
|
||||||
|
name_regex = r"\b(" + "|".join(name_prefixes) + r")_([0-9]+)\b"
|
||||||
|
names = {}
|
||||||
|
for (name, index) in re.findall(name_regex, msg):
|
||||||
|
names.setdefault(name, set()).add(int(index))
|
||||||
|
for name, indices in names.items():
|
||||||
|
for new_index, old_index in enumerate(sorted(list(indices))):
|
||||||
|
msg = re.sub(r"\b" + name + "_" + str(old_index) + r"\b",
|
||||||
|
name + "_" + str(new_index), msg)
|
||||||
|
return msg
|
||||||
|
|
||||||
|
norm_graph = text_format.Parse(normalize_uids(graph), graph_pb2.GraphDef())
|
||||||
|
norm_subgraph = text_format.Parse(
|
||||||
|
normalize_uids(subgraph), graph_pb2.GraphDef())
|
||||||
|
|
||||||
|
# Graph S is contained in C if and only if merge(C,S) == C.
|
||||||
|
# We merge the input graph with an empty graph to normalize repeated fields:
|
||||||
|
# assertProtoEquals is sensitive to ordering.
|
||||||
|
norm_graph = _GraphMerger.merge_graphs(norm_graph, graph_pb2.GraphDef())
|
||||||
|
merged_graph = _GraphMerger.merge_graphs(norm_graph, norm_subgraph)
|
||||||
|
self.assertProtoEquals(norm_graph, merged_graph)
|
||||||
|
|
||||||
|
def _ensure_no_variables_in_graph(self, graph_def):
|
||||||
|
"""Ensures there are no variables in the graph."""
|
||||||
|
for node in graph_def.node:
|
||||||
|
self.assertNotIn(
|
||||||
|
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
||||||
|
|
||||||
|
def _test_variable_to_const_conversion(self, use_resource):
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=use_resource):
|
||||||
|
variable_node = variable_scope.get_variable(
|
||||||
|
"variable_node", initializer=1.0)
|
||||||
|
variable_scope.get_variable("unused_variable_node", initializer=1.0)
|
||||||
|
output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
self.evaluate(variable_node.initializer)
|
||||||
|
output = self.evaluate(output_node)
|
||||||
|
self.assertNear(2.0, output, 0.00001)
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
session=sess,
|
||||||
|
graph_def=variable_graph_def,
|
||||||
|
output_node_names=["output_node"]))
|
||||||
|
|
||||||
|
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||||
|
|
||||||
|
# Now we make sure the variable is now a constant, and that the graph still
|
||||||
|
# produces the expected result.
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
_ = importer.import_graph_def(constant_graph_def, name="")
|
||||||
|
self.assertEqual(4, len(constant_graph_def.node))
|
||||||
|
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
||||||
|
output = self.evaluate(output_node)
|
||||||
|
self.assertNear(2.0, output, 0.00001)
|
||||||
|
|
||||||
|
def test_resource_variable_can_be_written_after_blacklisting(self):
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=True):
|
||||||
|
variable_node = variable_scope.get_variable(
|
||||||
|
"variable_node", initializer=1.0)
|
||||||
|
another_variable = variable_scope.get_variable(
|
||||||
|
"unused_variable_node", initializer=2.0)
|
||||||
|
with ops.control_dependencies(
|
||||||
|
[variable_node.assign(another_variable + variable_node)]):
|
||||||
|
output_node = array_ops.identity(variable_node, name="output_node")
|
||||||
|
initializer_name = variable_node.initializer.name
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
self.evaluate(variable_node.initializer)
|
||||||
|
self.evaluate(another_variable.initializer)
|
||||||
|
output = self.evaluate(output_node)
|
||||||
|
self.assertNear(3.0, output, 0.00001)
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
|
||||||
|
# Test variable name black list. This should result in the variable
|
||||||
|
# not being a const. Furthermore, the paths that read from and assign
|
||||||
|
# to the blacklisted variable should continue to be valid.
|
||||||
|
constant_graph_def_with_blacklist = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
session=sess,
|
||||||
|
graph_def=variable_graph_def,
|
||||||
|
output_node_names=["output_node", initializer_name],
|
||||||
|
variable_names_blacklist=set(["variable_node"])))
|
||||||
|
|
||||||
|
variable_node = None
|
||||||
|
for node in constant_graph_def_with_blacklist.node:
|
||||||
|
if node.name == "variable_node":
|
||||||
|
variable_node = node
|
||||||
|
self.assertIsNotNone(variable_node)
|
||||||
|
self.assertEqual(variable_node.op, "VarHandleOp")
|
||||||
|
|
||||||
|
# Now we make sure another_variable is now a constant, but the original
|
||||||
|
# variable is not, and that the graph can be executed and update the
|
||||||
|
# variable can be updated with each execution.
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
_ = importer.import_graph_def(constant_graph_def_with_blacklist, name="")
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
||||||
|
self.evaluate(sess.graph.get_operation_by_name(initializer_name))
|
||||||
|
output = self.evaluate(output_node)
|
||||||
|
self.assertNear(3.0, output, 0.00001)
|
||||||
|
output = self.evaluate(output_node)
|
||||||
|
self.assertNear(5.0, output, 0.00001)
|
||||||
|
|
||||||
|
def _inline_functions(self, graph_def, arrays):
|
||||||
|
meta_graph = export_meta_graph(graph_def=graph_def)
|
||||||
|
fetch_collection = meta_graph_pb2.CollectionDef()
|
||||||
|
for name in arrays:
|
||||||
|
fetch_collection.node_list.value.append(name)
|
||||||
|
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
|
||||||
|
|
||||||
|
# Initialize RewriterConfig with everything disabled except function
|
||||||
|
# inlining.
|
||||||
|
config = config_pb2.ConfigProto()
|
||||||
|
rewrite_options = config.graph_options.rewrite_options
|
||||||
|
rewrite_options.optimizers.append("function")
|
||||||
|
return tf_optimizer.OptimizeGraph(config, meta_graph)
|
||||||
|
|
||||||
|
def _test_convert_variables_with_functions(self, inline_functions):
|
||||||
|
"""Freezes a graph with functions."""
|
||||||
|
|
||||||
|
@function.Defun(dtypes.float32)
|
||||||
|
def plus_one(x):
|
||||||
|
return x + 1.0
|
||||||
|
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
variable_node = variables.Variable(1.0, name="variable_node")
|
||||||
|
_ = variables.Variable(1.0, name="unused_variable_node")
|
||||||
|
defun_node = plus_one(variable_node)
|
||||||
|
_ = math_ops.multiply(defun_node, 2.0, name="output_node")
|
||||||
|
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
self.evaluate(variables.variables_initializer([variable_node]))
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
|
||||||
|
if inline_functions:
|
||||||
|
# Run Grappler to create the VarOpHandle --> Placeholder -->
|
||||||
|
# ResourceVariable pattern.
|
||||||
|
variable_graph_def = self._inline_functions(
|
||||||
|
variable_graph_def, ["variable_node", "output_node"])
|
||||||
|
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
session=sess,
|
||||||
|
graph_def=variable_graph_def,
|
||||||
|
output_node_names=["output_node"]))
|
||||||
|
|
||||||
|
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||||
|
|
||||||
|
def testReferenceVariables(self):
|
||||||
|
"""Freezes a graph with reference variables."""
|
||||||
|
self._test_variable_to_const_conversion(use_resource=False)
|
||||||
|
|
||||||
|
def testResourceVariables(self):
|
||||||
|
"""Freezes a graph with resource variables."""
|
||||||
|
self._test_variable_to_const_conversion(use_resource=True)
|
||||||
|
|
||||||
|
def testWithFunctions(self):
|
||||||
|
"""Freezes a graph with functions."""
|
||||||
|
self._test_convert_variables_with_functions(inline_functions=False)
|
||||||
|
|
||||||
|
def testWithInlinedFunctions(self):
|
||||||
|
"""Freezes a graph with functions that have been inlined using Grappler."""
|
||||||
|
self._test_convert_variables_with_functions(inline_functions=True)
|
||||||
|
|
||||||
|
def testGraphWithSwitch(self):
|
||||||
|
"""Freezes a graph which contains a Switch with type RESOURCE_DT."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=True):
|
||||||
|
x = variable_scope.get_variable("var_x", initializer=1.0)
|
||||||
|
y = variable_scope.get_variable("var_y", initializer=2.0)
|
||||||
|
f1 = lambda: variable_scope.get_variable("var_f1", initializer=17.0)
|
||||||
|
f2 = lambda: variable_scope.get_variable("var_f2", initializer=23.0)
|
||||||
|
cond_node = control_flow_ops.case([(gen_math_ops.less(x, y), f1)],
|
||||||
|
default=f2)
|
||||||
|
_ = math_ops.multiply(cond_node, 2.0, name="output_node")
|
||||||
|
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
session=sess,
|
||||||
|
graph_def=variable_graph_def,
|
||||||
|
output_node_names=["output_node"]))
|
||||||
|
|
||||||
|
self._ensure_no_variables_in_graph(constant_graph_def)
|
||||||
|
|
||||||
|
def testConvertSingleVariable(self):
|
||||||
|
"""Tests that a single variable is properly converted to a constant."""
|
||||||
|
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=False):
|
||||||
|
_ = variable_scope.get_variable("x", initializer=1.0)
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
sess, variable_graph_def, ["x/read"]))
|
||||||
|
self._assertGraphContains(
|
||||||
|
constant_graph_def, """
|
||||||
|
node {
|
||||||
|
name: "x" op: "Const"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "x/read" op: "Identity" input: "x"
|
||||||
|
attr { key: "T" value { type: DT_FLOAT } }
|
||||||
|
}""")
|
||||||
|
|
||||||
|
def testConvertSingleResourceVariable(self):
|
||||||
|
"""Tests that a resource variable is properly converted to a constant."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=True):
|
||||||
|
_ = variable_scope.get_variable("x", initializer=1.0)
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
sess, variable_graph_def, ["x/Read/ReadVariableOp"]))
|
||||||
|
self._assertGraphContains(
|
||||||
|
constant_graph_def, """
|
||||||
|
node {
|
||||||
|
name: "x" op: "Const"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "x/Read/ReadVariableOp" op: "Identity" input: "x"
|
||||||
|
attr { key: "T" value { type: DT_FLOAT } }
|
||||||
|
}""")
|
||||||
|
|
||||||
|
def testConvertOneVariableOfTwo(self):
|
||||||
|
"""Tests that one variable can be kept unconverted."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=False):
|
||||||
|
x = variable_scope.get_variable("x", initializer=1.0)
|
||||||
|
y = variable_scope.get_variable("y", initializer=1.0)
|
||||||
|
_ = math_ops.multiply(x, y, name="out")
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
sess,
|
||||||
|
variable_graph_def, ["out"],
|
||||||
|
variable_names_blacklist=["y"]))
|
||||||
|
self._assertGraphContains(
|
||||||
|
constant_graph_def, """
|
||||||
|
node {
|
||||||
|
name: "x" op: "Const"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "x/read" op: "Identity" input: "x"
|
||||||
|
attr { key: "T" value { type: DT_FLOAT } }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "y" op: "VariableV2"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "y/read" op: "Identity" input: "y"
|
||||||
|
attr { key: "T" value { type: DT_FLOAT } }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "out" op: "Mul" input: "x/read" input: "y/read"
|
||||||
|
attr {key: "T" value {type: DT_FLOAT}}
|
||||||
|
}""")
|
||||||
|
|
||||||
|
def testConvertOneResourceVariableOfTwo(self):
|
||||||
|
"""Tests that one variable can be kept unconverted."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=True):
|
||||||
|
x = variable_scope.get_variable("x", initializer=1.0)
|
||||||
|
y = variable_scope.get_variable("y", initializer=1.0)
|
||||||
|
_ = math_ops.multiply(x, y, name="out")
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
sess,
|
||||||
|
variable_graph_def, ["out"],
|
||||||
|
variable_names_blacklist=["y"]))
|
||||||
|
self._assertGraphContains(
|
||||||
|
constant_graph_def, """
|
||||||
|
node {
|
||||||
|
name: "x" op: "Const"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "y" op: "VarHandleOp"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "out/ReadVariableOp" op: "Identity" input: "x"
|
||||||
|
attr { key: "T" value { type: DT_FLOAT } }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "out/ReadVariableOp_1" op: "ReadVariableOp" input: "y"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "out" op: "Mul"
|
||||||
|
input: "out/ReadVariableOp" input: "out/ReadVariableOp_1"
|
||||||
|
attr {key: "T" value {type: DT_FLOAT}}
|
||||||
|
}""")
|
||||||
|
|
||||||
|
def testConvertIdentityChain(self):
|
||||||
|
"""Tests that a chain of Identity ops is converted properly."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=True):
|
||||||
|
x = variable_scope.get_variable("x", initializer=1.0)
|
||||||
|
y = array_ops.identity(x, name="y")
|
||||||
|
_ = array_ops.identity(y, name="z")
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
sess, variable_graph_def, ["z"]))
|
||||||
|
self._assertGraphContains(
|
||||||
|
constant_graph_def, """
|
||||||
|
node {
|
||||||
|
name: "x" op: "Const"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "y/ReadVariableOp" op: "Identity" input: "x"
|
||||||
|
attr { key: "T" value { type: DT_FLOAT } }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "y" op: "Identity" input: "y/ReadVariableOp"
|
||||||
|
attr { key: "T" value { type: DT_FLOAT } }
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "z" op: "Identity" input: "y"
|
||||||
|
attr { key: "T" value { type: DT_FLOAT } }
|
||||||
|
}""")
|
||||||
|
|
||||||
|
def testConvertCase(self):
|
||||||
|
"""Tests that a v1 case() construction converts properly."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=False):
|
||||||
|
control_flow_v2_toggles.disable_control_flow_v2()
|
||||||
|
x = variable_scope.get_variable("x", initializer=1.0)
|
||||||
|
y = variable_scope.get_variable("y", initializer=2.0)
|
||||||
|
_ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)],
|
||||||
|
default=lambda: y)
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
sess, variable_graph_def, ["case/cond/Merge"]))
|
||||||
|
self._assertGraphContains(
|
||||||
|
constant_graph_def, """
|
||||||
|
node {
|
||||||
|
name: "x" op: "Const"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "y" op: "Const"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}}
|
||||||
|
}
|
||||||
|
node {name: "x/read" op: "Identity" input: "x"}
|
||||||
|
node {name: "y/read" op: "Identity" input: "y"}
|
||||||
|
node {name: "Less" op: "Less" input: "x/read" input: "y/read"}
|
||||||
|
node {name: "case/cond/pred_id" op: "Identity" input: "Less"}
|
||||||
|
node {
|
||||||
|
name: "case/cond/Switch_1" op: "Switch"
|
||||||
|
input: "case/cond/pred_id" input: "x/read"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "case/cond/Switch_2" op: "Switch"
|
||||||
|
input: "case/cond/pred_id" input: "y/read"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "case/cond/Merge" op: "Merge"
|
||||||
|
input: "case/cond/Switch_2" input: "case/cond/Switch_1:1"
|
||||||
|
attr {key: "T" value {type: DT_FLOAT}}
|
||||||
|
}""")
|
||||||
|
|
||||||
|
def testConvertV2Case(self):
|
||||||
|
"""Tests that a v2 case() converts properly."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=False):
|
||||||
|
control_flow_v2_toggles.enable_control_flow_v2()
|
||||||
|
a = variable_scope.get_variable("a", initializer=2.0)
|
||||||
|
x = variable_scope.get_variable("x", initializer=1.0)
|
||||||
|
y = variable_scope.get_variable("y", initializer=2.0)
|
||||||
|
_ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: a)],
|
||||||
|
default=lambda: y)
|
||||||
|
control_flow_v2_toggles.disable_control_flow_v2()
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
sess, variable_graph_def, ["case/cond"]))
|
||||||
|
self._assertGraphContains(
|
||||||
|
constant_graph_def, """
|
||||||
|
node {
|
||||||
|
name: "x" op: "Const"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "y" op: "Const"
|
||||||
|
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}}
|
||||||
|
}
|
||||||
|
node {name: "x/read" op: "Identity" input: "x"}
|
||||||
|
node {name: "y/read" op: "Identity" input: "y"}
|
||||||
|
node {name: "Less" op: "Less" input: "x/read" input: "y/read"}
|
||||||
|
node {
|
||||||
|
name: "case/cond" op: "StatelessIf"
|
||||||
|
input: "Less" input: "a/read" input: "y/read"
|
||||||
|
attr {key: "Tcond" value {type: DT_BOOL}}
|
||||||
|
attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}}
|
||||||
|
attr {key: "Tout" value {list {type: DT_FLOAT}}}
|
||||||
|
}
|
||||||
|
library {
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "case_cond_false_frozen_0"
|
||||||
|
input_arg {name: "placeholder" type: DT_FLOAT}
|
||||||
|
input_arg {name: "y_read_0" type: DT_FLOAT}
|
||||||
|
output_arg {name: "y_read" type: DT_FLOAT}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "case_cond_true_frozen_0"
|
||||||
|
input_arg {name: "a_read_0" type: DT_FLOAT}
|
||||||
|
input_arg {name: "placeholder" type: DT_FLOAT}
|
||||||
|
output_arg {name: "a_read" type: DT_FLOAT}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}""")
|
||||||
|
|
||||||
|
def testConvertV2ResourceCase(self):
|
||||||
|
"""Tests that a v2 case() with resource variables converts properly."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=True):
|
||||||
|
control_flow_v2_toggles.enable_control_flow_v2()
|
||||||
|
x = variable_scope.get_variable("x", initializer=1.0)
|
||||||
|
y = variable_scope.get_variable("y", initializer=2.0)
|
||||||
|
_ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)],
|
||||||
|
default=lambda: y)
|
||||||
|
control_flow_v2_toggles.disable_control_flow_v2()
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
sess, variable_graph_def, ["case/cond"]))
|
||||||
|
self._assertGraphContains(
|
||||||
|
constant_graph_def, """
|
||||||
|
node {name: "x" op: "Const"}
|
||||||
|
node {name: "y" op: "Const"}
|
||||||
|
node {
|
||||||
|
name: "case/cond" op: "If" input: "Less" input: "x" input: "y"
|
||||||
|
attr {key: "Tcond" value {type: DT_BOOL}}
|
||||||
|
attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}}
|
||||||
|
attr {key: "Tout" value {list {type: DT_FLOAT}}}
|
||||||
|
}
|
||||||
|
library {
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "case_cond_false_frozen_0"
|
||||||
|
input_arg {name: "placeholder" type: DT_FLOAT}
|
||||||
|
input_arg {name: "readvariableop_y" type: DT_FLOAT}
|
||||||
|
output_arg {name: "readvariableop" type: DT_FLOAT}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "case_cond_true_frozen_0"
|
||||||
|
input_arg {name: "placeholder" type: DT_FLOAT}
|
||||||
|
input_arg {name: "readvariableop_x" type: DT_FLOAT}
|
||||||
|
output_arg {name: "readvariableop" type: DT_FLOAT}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}""")
|
||||||
|
|
||||||
|
def testConvertV2UnconvertedResourceNestedCase(self):
|
||||||
|
"""Tests unconverted variable propagation through nested functions."""
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
with variable_scope.variable_scope("", use_resource=True):
|
||||||
|
control_flow_v2_toggles.enable_control_flow_v2()
|
||||||
|
x = variable_scope.get_variable("x", initializer=1.0)
|
||||||
|
y = variable_scope.get_variable("y", initializer=2.0)
|
||||||
|
z = variable_scope.get_variable("z", initializer=3.0)
|
||||||
|
# pylint: disable=g-long-lambda
|
||||||
|
_ = control_flow_ops.case(
|
||||||
|
[(gen_math_ops.less(x, y), lambda: x)],
|
||||||
|
default=lambda: control_flow_ops.case(
|
||||||
|
[(gen_math_ops.less(z, y), lambda: z)], default=lambda: y))
|
||||||
|
# pylint: enable=g-long-lambda
|
||||||
|
control_flow_v2_toggles.disable_control_flow_v2()
|
||||||
|
with session_lib.Session() as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
variable_graph_def = sess.graph.as_graph_def()
|
||||||
|
constant_graph_def = (
|
||||||
|
convert_to_constants
|
||||||
|
.convert_variables_to_constants_from_session_graph(
|
||||||
|
sess,
|
||||||
|
variable_graph_def, ["case/cond"],
|
||||||
|
variable_names_blacklist=["y"]))
|
||||||
|
self._assertGraphContains(
|
||||||
|
constant_graph_def, """
|
||||||
|
node {name: "x" op: "Const"}
|
||||||
|
node {name: "y" op: "VarHandleOp"}
|
||||||
|
node {name: "z" op: "Const"}
|
||||||
|
|
||||||
|
node {name: "Less/ReadVariableOp" op: "Identity" input: "x"}
|
||||||
|
node {name: "Less/ReadVariableOp_1" op: "ReadVariableOp" input: "y"}
|
||||||
|
|
||||||
|
node {
|
||||||
|
name: "case/cond" op: "If"
|
||||||
|
input: "x" input: "z" input: "y"
|
||||||
|
attr {
|
||||||
|
key: "Tin"
|
||||||
|
value {list
|
||||||
|
{type: DT_FLOAT type: DT_FLOAT type: DT_RESOURCE}}}
|
||||||
|
attr {
|
||||||
|
key: "_read_only_resource_inputs"
|
||||||
|
value {list {i: 1 i: 2 i: 3}}}
|
||||||
|
attr {key: "then_branch"
|
||||||
|
value {func {name: "case_cond_true_frozen_0"}}}
|
||||||
|
attr {key: "else_branch"
|
||||||
|
value {func {name: "case_cond_false_frozen_0"}}}
|
||||||
|
attr {key: "output_shapes" value {list {shape {}}}}
|
||||||
|
}
|
||||||
|
library {
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "case_cond_true_frozen_0"
|
||||||
|
input_arg {name: "placeholder" type: DT_FLOAT}
|
||||||
|
input_arg {name: "placeholder_1" type: DT_RESOURCE}
|
||||||
|
input_arg {name: "readvariableop_x" type: DT_FLOAT}
|
||||||
|
output_arg {name: "readvariableop" type: DT_FLOAT}
|
||||||
|
is_stateful: true
|
||||||
|
}
|
||||||
|
|
||||||
|
node_def {name: "ReadVariableOp" op: "Identity"
|
||||||
|
input: "readvariableop_x"}}
|
||||||
|
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "case_cond_false_frozen_0"
|
||||||
|
input_arg {name: "placeholder" type: DT_FLOAT}
|
||||||
|
input_arg {name: "less_readvariableop_1_y" type: DT_RESOURCE}
|
||||||
|
input_arg {name: "less_readvariableop_z" type: DT_FLOAT}
|
||||||
|
output_arg {name: "case_cond_identity" type: DT_FLOAT}
|
||||||
|
is_stateful: true
|
||||||
|
}
|
||||||
|
|
||||||
|
node_def {name: "Less/ReadVariableOp_1" op: "ReadVariableOp"
|
||||||
|
input: "less_readvariableop_1_y"}
|
||||||
|
|
||||||
|
node_def {name: "Less/ReadVariableOp" op: "Identity"
|
||||||
|
input: "less_readvariableop_z"}
|
||||||
|
|
||||||
|
node_def {name: "case/cond" op: "If"
|
||||||
|
input: "less_readvariableop_z"
|
||||||
|
input: "less_readvariableop_1_y"
|
||||||
|
attr {
|
||||||
|
key: "Tin"
|
||||||
|
value {list {type: DT_FLOAT type: DT_RESOURCE}}}
|
||||||
|
attr {key: "then_branch"
|
||||||
|
value {func {name: "case_cond_true_frozen_1"}}}
|
||||||
|
attr {key: "else_branch"
|
||||||
|
value {func {name: "case_cond_false_frozen_1"}}}
|
||||||
|
attr {
|
||||||
|
key: "_read_only_resource_inputs"
|
||||||
|
value {list {i: 1 i: 2}}}}}
|
||||||
|
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "case_cond_false_frozen_1"
|
||||||
|
input_arg {name: "placeholder" type: DT_FLOAT}
|
||||||
|
input_arg {name: "readvariableop_y" type: DT_RESOURCE}
|
||||||
|
output_arg {name: "readvariableop" type: DT_FLOAT}
|
||||||
|
is_stateful: true
|
||||||
|
}
|
||||||
|
|
||||||
|
node_def {name: "ReadVariableOp" op: "ReadVariableOp"
|
||||||
|
input: "readvariableop_y"}}
|
||||||
|
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "case_cond_true_frozen_1"
|
||||||
|
input_arg {name: "placeholder" type: DT_RESOURCE}
|
||||||
|
input_arg {name: "readvariableop_z" type: DT_FLOAT}
|
||||||
|
output_arg {name: "readvariableop" type: DT_FLOAT}
|
||||||
|
is_stateful: true
|
||||||
|
}
|
||||||
|
|
||||||
|
node_def {name: "ReadVariableOp" op: "Identity"
|
||||||
|
input: "readvariableop_z"}}}""")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -23,16 +23,19 @@ import re
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.framework import node_def_pb2
|
from tensorflow.core.framework import node_def_pb2
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_util
|
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
|
from tensorflow.python.util import lazy_loader
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
# A normal import here would generate circular dependencies.
|
||||||
|
convert_to_constants = lazy_loader.LazyLoader(
|
||||||
|
"convert_to_constants", globals(),
|
||||||
|
"tensorflow.python.framework.convert_to_constants")
|
||||||
|
|
||||||
_VARIABLE_OPS = {
|
_VARIABLE_OPS = {
|
||||||
"Assign",
|
"Assign",
|
||||||
"AssignAdd",
|
"AssignAdd",
|
||||||
@ -237,76 +240,6 @@ def tensor_shape_from_node_def_name(graph, input_name):
|
|||||||
return shape
|
return shape
|
||||||
|
|
||||||
|
|
||||||
def _update_resource_identities(resource_identities, output_graph_def,
|
|
||||||
variable_names_whitelist,
|
|
||||||
variable_names_blacklist):
|
|
||||||
"""Updates the type of DT_RESOURCE Identity ops.
|
|
||||||
|
|
||||||
Updates the type of the `resource_identities` to the type of the node that
|
|
||||||
feed into it if the node is not an input to any other node. Valid nodes are
|
|
||||||
generally colocated nodes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
resource_identities: List of NodeDef protos that are Identity ops with the
|
|
||||||
type DT_RESOURCE.
|
|
||||||
output_graph_def: GraphDef proto.
|
|
||||||
variable_names_whitelist: The set of variable names to convert (by default,
|
|
||||||
all variables are converted).
|
|
||||||
variable_names_blacklist: The set of variable names to omit converting
|
|
||||||
to constants.
|
|
||||||
"""
|
|
||||||
# Identify the nodes in the graph and the nodes consuming each node.
|
|
||||||
map_name_to_node = {}
|
|
||||||
map_name_to_inputs = {}
|
|
||||||
for node in output_graph_def.node:
|
|
||||||
map_name_to_node[node.name] = node
|
|
||||||
for unparsed_input_name in node.input:
|
|
||||||
if not unparsed_input_name.startswith("^"):
|
|
||||||
parsed_input_name = _node_name(unparsed_input_name)
|
|
||||||
if parsed_input_name not in map_name_to_inputs:
|
|
||||||
map_name_to_inputs[parsed_input_name] = []
|
|
||||||
map_name_to_inputs[parsed_input_name].append(node.name)
|
|
||||||
|
|
||||||
for node in resource_identities:
|
|
||||||
# Validate the node is not an input to other nodes.
|
|
||||||
if node.name in map_name_to_inputs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get the type of the Identity node by tracing back through the nodes until
|
|
||||||
# we come to a non-Identity or non-control flow node or the type of the node
|
|
||||||
# is not DT_RESOURCE.
|
|
||||||
input_node = map_name_to_node[_node_name(node.input[0])]
|
|
||||||
while (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY and
|
|
||||||
input_node.attr["T"].type == dtypes.resource):
|
|
||||||
input_node = map_name_to_node[_node_name(input_node.input[0])]
|
|
||||||
|
|
||||||
# Update the type of the Identity node if an Identity, control flow, or
|
|
||||||
# VarHandleOp node with a type that is not DT_RESOURCE is found.
|
|
||||||
debugging_message = str.encode(
|
|
||||||
"This Identity's type was changed from DT_RESOURCE during graph "
|
|
||||||
"freezing.")
|
|
||||||
if input_node.attr["T"].type != dtypes.resource:
|
|
||||||
if (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY
|
|
||||||
and _should_convert(
|
|
||||||
input_node.input[0],
|
|
||||||
variable_names_whitelist,
|
|
||||||
variable_names_blacklist)):
|
|
||||||
node.attr["T"].CopyFrom(input_node.attr["T"])
|
|
||||||
node.attr["_debugging"].s = debugging_message
|
|
||||||
elif (input_node.op == "VarHandleOp"
|
|
||||||
and _should_convert(
|
|
||||||
input_node.name,
|
|
||||||
variable_names_whitelist,
|
|
||||||
variable_names_blacklist)):
|
|
||||||
node.attr["T"].CopyFrom(input_node.attr["dtype"])
|
|
||||||
node.attr["_debugging"].s = debugging_message
|
|
||||||
|
|
||||||
|
|
||||||
def _should_convert(name, whitelist, blacklist):
|
|
||||||
return ((whitelist is None or name in whitelist)
|
|
||||||
and (blacklist is None or name not in blacklist))
|
|
||||||
|
|
||||||
|
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
date=None,
|
date=None,
|
||||||
instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`")
|
instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`")
|
||||||
@ -339,190 +272,16 @@ def convert_variables_to_constants(sess,
|
|||||||
RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both
|
RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both
|
||||||
blacklisted AND whitelisted for freezing.
|
blacklisted AND whitelisted for freezing.
|
||||||
"""
|
"""
|
||||||
|
ret = convert_to_constants.convert_variables_to_constants_from_session_graph(
|
||||||
get_input_name = lambda node, index=0: node.input[index].split(":")[0]
|
session=sess,
|
||||||
|
graph_def=input_graph_def,
|
||||||
def create_const_op(node_name, dtype, data, data_shape=None):
|
output_node_names=output_node_names,
|
||||||
"""Creates a Const op."""
|
variable_names_whitelist=variable_names_whitelist,
|
||||||
output_node = node_def_pb2.NodeDef()
|
variable_names_blacklist=variable_names_blacklist)
|
||||||
output_node.op = "Const"
|
# The previous code logic generated an empty versions field, we clear it here
|
||||||
output_node.name = node_name
|
# to maintain backwards compatibility.
|
||||||
output_node.attr["dtype"].CopyFrom(dtype)
|
ret.versions.Clear()
|
||||||
output_node.attr["value"].CopyFrom(
|
return ret
|
||||||
attr_value_pb2.AttrValue(
|
|
||||||
tensor=tensor_util.make_tensor_proto(
|
|
||||||
data, dtype=dtype.type, shape=data_shape)))
|
|
||||||
return output_node
|
|
||||||
|
|
||||||
# This graph only includes the nodes needed to evaluate the output nodes, and
|
|
||||||
# removes unneeded nodes like those involved in saving and assignment.
|
|
||||||
inference_graph = extract_sub_graph(input_graph_def, output_node_names)
|
|
||||||
|
|
||||||
# Identify the ops in the graph.
|
|
||||||
map_name_to_node = {
|
|
||||||
node.name: node for node in inference_graph.node
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get list of variables.
|
|
||||||
variable_names = []
|
|
||||||
variable_dict_names = []
|
|
||||||
resource_op_types = {}
|
|
||||||
|
|
||||||
for node in inference_graph.node:
|
|
||||||
if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
|
|
||||||
variable_name = node.name
|
|
||||||
if not _should_convert(
|
|
||||||
variable_name, variable_names_whitelist, variable_names_blacklist):
|
|
||||||
continue
|
|
||||||
variable_dict_names.append(variable_name)
|
|
||||||
if node.op == "VarHandleOp":
|
|
||||||
variable_names.append(variable_name + "/Read/ReadVariableOp:0")
|
|
||||||
else:
|
|
||||||
variable_names.append(variable_name + ":0")
|
|
||||||
elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd"]:
|
|
||||||
# There can be one or more Identity or control flow ops in between the
|
|
||||||
# ReadVariableOp and VarHandleOp. Store the ops with the associated
|
|
||||||
# dtypes.
|
|
||||||
source_op_names = [get_input_name(node)]
|
|
||||||
candidate_resource_op_types = {}
|
|
||||||
while (source_op_names and map_name_to_node[source_op_names[0]].op in
|
|
||||||
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
|
|
||||||
source_op_name = source_op_names.pop()
|
|
||||||
current_node = map_name_to_node[source_op_name]
|
|
||||||
|
|
||||||
if (source_op_name not in resource_op_types and
|
|
||||||
source_op_name not in candidate_resource_op_types):
|
|
||||||
candidate_resource_op_types[source_op_name] = node.attr["dtype"]
|
|
||||||
source_op_names.append(get_input_name(current_node))
|
|
||||||
|
|
||||||
if current_node == "Merge":
|
|
||||||
merge_resource_name = get_input_name(current_node, index=1)
|
|
||||||
if (merge_resource_name not in resource_op_types
|
|
||||||
and merge_resource_name not in candidate_resource_op_types):
|
|
||||||
candidate_resource_op_types[merge_resource_name] = (
|
|
||||||
node.attr["dtype"])
|
|
||||||
source_op_names.append(
|
|
||||||
get_input_name(map_name_to_node[merge_resource_name]))
|
|
||||||
|
|
||||||
should_convert_all = None
|
|
||||||
for source_node in source_op_names:
|
|
||||||
if map_name_to_node[source_node].op != "VarHandleOp":
|
|
||||||
raise ValueError("Cannot find the variable that is an input "
|
|
||||||
"to the ReadVariableOp.")
|
|
||||||
should_convert_node = _should_convert(
|
|
||||||
source_node, variable_names_whitelist, variable_names_blacklist)
|
|
||||||
if should_convert_all is None:
|
|
||||||
should_convert_all = should_convert_node
|
|
||||||
elif should_convert_all != should_convert_node:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Found DT_RESOURCE node whose ancestor Variables are both "
|
|
||||||
"blacklisted AND whitelisted for freezing. Originating "
|
|
||||||
"descendant node: {}. Ancestor variables: {}.".format(
|
|
||||||
node.name, source_op_names))
|
|
||||||
if should_convert_all in (None, True):
|
|
||||||
resource_op_types.update(candidate_resource_op_types)
|
|
||||||
|
|
||||||
# Gets map of variables and the associated data.
|
|
||||||
if variable_names:
|
|
||||||
returned_variables = sess.run(variable_names)
|
|
||||||
else:
|
|
||||||
returned_variables = []
|
|
||||||
variables_data_map = dict(zip(variable_dict_names, returned_variables))
|
|
||||||
logging.info("Froze %d variables.", len(returned_variables))
|
|
||||||
|
|
||||||
def _should_convert_ancestor(node):
|
|
||||||
input_node = map_name_to_node[_node_name(node.input[0])]
|
|
||||||
while (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY and
|
|
||||||
input_node.attr["T"].type == dtypes.resource):
|
|
||||||
input_node = map_name_to_node[_node_name(input_node.input[0])]
|
|
||||||
return _should_convert(input_node.name,
|
|
||||||
variable_names_whitelist,
|
|
||||||
variable_names_blacklist)
|
|
||||||
|
|
||||||
# Reconstruct the graph with constants in place of variables.
|
|
||||||
output_graph_def = graph_pb2.GraphDef()
|
|
||||||
how_many_converted = 0
|
|
||||||
for input_node in inference_graph.node:
|
|
||||||
output_node = node_def_pb2.NodeDef()
|
|
||||||
if input_node.name in variables_data_map:
|
|
||||||
data = variables_data_map[input_node.name]
|
|
||||||
output_node = create_const_op(input_node.name, input_node.attr["dtype"],
|
|
||||||
data, data.shape)
|
|
||||||
how_many_converted += 1
|
|
||||||
elif input_node.name in resource_op_types:
|
|
||||||
# Converts the type of the ops between the ReadVariableOp and VarHandleOp
|
|
||||||
# from RESOURCE_DT to the appropriate type based on the input they are
|
|
||||||
# referencing. Do not copy shapes due to incorrect shape info.
|
|
||||||
output_node.op = input_node.op
|
|
||||||
output_node.name = input_node.name
|
|
||||||
for in_node in input_node.input:
|
|
||||||
output_node.input.append(in_node)
|
|
||||||
for attr_name in input_node.attr:
|
|
||||||
if str(attr_name) != "_output_shapes":
|
|
||||||
output_node.attr[attr_name].CopyFrom(input_node.attr[attr_name])
|
|
||||||
output_node.attr["T"].CopyFrom(resource_op_types[input_node.name])
|
|
||||||
elif (input_node.op == "ReadVariableOp"
|
|
||||||
and _should_convert_ancestor(input_node)):
|
|
||||||
# The first branch converts all VarHandleOps of ResourceVariables to
|
|
||||||
# constants, so we need to convert the associated ReadVariableOps to
|
|
||||||
# Identity ops.
|
|
||||||
output_node.op = "Identity"
|
|
||||||
output_node.name = input_node.name
|
|
||||||
output_node.input.extend([input_node.input[0]])
|
|
||||||
output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
|
|
||||||
if "_class" in input_node.attr:
|
|
||||||
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
|
|
||||||
elif (input_node.op == "ResourceGather"
|
|
||||||
and _should_convert_ancestor(input_node)):
|
|
||||||
# The first branch converts all VarHandleOps of ResourceGather to
|
|
||||||
# constants, so we need to convert the associated ResourceGather to Gather
|
|
||||||
# ops with a Const axis feeding into it.
|
|
||||||
if input_node.attr["batch_dims"].i != 0:
|
|
||||||
raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
|
|
||||||
axis_data = input_node.attr["batch_dims"].i
|
|
||||||
axis_node_name = input_node.name + "/axis"
|
|
||||||
axis_dtype = input_node.attr["Tindices"]
|
|
||||||
output_axis_node = create_const_op(axis_node_name, axis_dtype, axis_data)
|
|
||||||
output_graph_def.node.extend([output_axis_node])
|
|
||||||
|
|
||||||
output_node.op = "GatherV2"
|
|
||||||
output_node.name = input_node.name
|
|
||||||
output_node.input.extend(
|
|
||||||
[input_node.input[0], input_node.input[1], axis_node_name])
|
|
||||||
output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
|
|
||||||
output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
|
|
||||||
output_node.attr["Taxis"].CopyFrom(axis_dtype)
|
|
||||||
if "_class" in input_node.attr:
|
|
||||||
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
|
|
||||||
elif (input_node.op == "ResourceGatherNd"
|
|
||||||
and _should_convert_ancestor(input_node)):
|
|
||||||
output_node.op = "GatherNd"
|
|
||||||
output_node.name = input_node.name
|
|
||||||
output_node.input.extend(
|
|
||||||
[input_node.input[0], input_node.input[1]])
|
|
||||||
output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
|
|
||||||
output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
|
|
||||||
if "_class" in input_node.attr:
|
|
||||||
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
|
|
||||||
else:
|
|
||||||
output_node.CopyFrom(input_node)
|
|
||||||
output_graph_def.node.append(output_node)
|
|
||||||
|
|
||||||
# Update the types of the DT_RESOURCE Identity nodes that do not have an
|
|
||||||
# associated ReadVariableOp.
|
|
||||||
resource_identities = []
|
|
||||||
for node in output_graph_def.node:
|
|
||||||
if node.op == "Identity" and node.attr["T"].type == dtypes.resource:
|
|
||||||
resource_identities.append(node)
|
|
||||||
if resource_identities:
|
|
||||||
_update_resource_identities(resource_identities,
|
|
||||||
output_graph_def,
|
|
||||||
variable_names_whitelist,
|
|
||||||
variable_names_blacklist)
|
|
||||||
|
|
||||||
output_graph_def.library.CopyFrom(inference_graph.library)
|
|
||||||
logging.info("Converted %d variables to const ops.", how_many_converted)
|
|
||||||
return output_graph_def
|
|
||||||
|
|
||||||
|
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
|
@ -21,27 +21,15 @@ from __future__ import print_function
|
|||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.framework import node_def_pb2
|
from tensorflow.core.framework import node_def_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
|
||||||
from tensorflow.python.client import session
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import function
|
|
||||||
from tensorflow.python.framework import graph_util
|
from tensorflow.python.framework import graph_util
|
||||||
from tensorflow.python.framework import importer
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.grappler import tf_optimizer
|
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import control_flow_ops
|
|
||||||
from tensorflow.python.ops import gen_math_ops
|
|
||||||
from tensorflow.python.ops import gen_state_ops
|
from tensorflow.python.ops import gen_state_ops
|
||||||
from tensorflow.python.ops import math_ops as math_ops_lib
|
|
||||||
from tensorflow.python.ops import variable_scope
|
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training.saver import export_meta_graph
|
|
||||||
|
|
||||||
|
|
||||||
# Utility device function to use for testing
|
# Utility device function to use for testing
|
||||||
@ -316,203 +304,5 @@ class DeviceFunctionsTest(test.TestCase):
|
|||||||
graph_util.remove_training_nodes(graph_def))
|
graph_util.remove_training_nodes(graph_def))
|
||||||
|
|
||||||
|
|
||||||
class ConvertVariablesToConstantsTest(test.TestCase):
|
|
||||||
|
|
||||||
def _ensure_no_variables_in_graph(self, graph_def):
|
|
||||||
"""Ensures there are no variables in the graph."""
|
|
||||||
for node in graph_def.node:
|
|
||||||
self.assertNotIn(
|
|
||||||
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
|
|
||||||
|
|
||||||
def _test_variable_to_const_conversion(self, use_resource):
|
|
||||||
with ops.Graph().as_default():
|
|
||||||
with variable_scope.variable_scope("", use_resource=use_resource):
|
|
||||||
variable_node = variable_scope.get_variable(
|
|
||||||
"variable_node", initializer=1.0)
|
|
||||||
another_variable = variable_scope.get_variable(
|
|
||||||
"unused_variable_node", initializer=1.0)
|
|
||||||
output_node = math_ops_lib.multiply(
|
|
||||||
variable_node, 2.0, name="output_node")
|
|
||||||
with session.Session() as sess:
|
|
||||||
self.evaluate(variable_node.initializer)
|
|
||||||
output = self.evaluate(output_node)
|
|
||||||
self.assertNear(2.0, output, 0.00001)
|
|
||||||
variable_graph_def = sess.graph.as_graph_def()
|
|
||||||
# First get the constant_graph_def when variable_names_whitelist is
|
|
||||||
# set, note that if variable_names_whitelist is not set an error will
|
|
||||||
# be thrown because unused_variable_node is not initialized.
|
|
||||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
|
||||||
sess,
|
|
||||||
variable_graph_def, ["output_node"],
|
|
||||||
variable_names_whitelist=set(["variable_node"]))
|
|
||||||
|
|
||||||
# Then initialize the unused variable, and get another
|
|
||||||
# constant_graph_def when variable_names_whitelist is not set.
|
|
||||||
self.evaluate(another_variable.initializer)
|
|
||||||
constant_graph_def_without_variable_whitelist = (
|
|
||||||
graph_util.convert_variables_to_constants(
|
|
||||||
sess, variable_graph_def, ["output_node"]))
|
|
||||||
|
|
||||||
# The unused variable should be cleared so the two graphs should be
|
|
||||||
# equivalent.
|
|
||||||
self.assertEqual(
|
|
||||||
str(constant_graph_def),
|
|
||||||
str(constant_graph_def_without_variable_whitelist))
|
|
||||||
|
|
||||||
# Test variable name black list. This should result in the variable
|
|
||||||
# not being a const.
|
|
||||||
constant_graph_def_with_blacklist = (
|
|
||||||
graph_util.convert_variables_to_constants(
|
|
||||||
sess,
|
|
||||||
variable_graph_def, ["output_node"],
|
|
||||||
variable_names_blacklist=set(["variable_node"])))
|
|
||||||
variable_node = None
|
|
||||||
for node in constant_graph_def_with_blacklist.node:
|
|
||||||
if node.name == "variable_node":
|
|
||||||
variable_node = node
|
|
||||||
self.assertIsNotNone(variable_node)
|
|
||||||
if use_resource:
|
|
||||||
self.assertEqual(variable_node.op, "VarHandleOp")
|
|
||||||
else:
|
|
||||||
self.assertEqual(variable_node.op, "VariableV2")
|
|
||||||
|
|
||||||
# Now we make sure the variable is now a constant, and that the graph still
|
|
||||||
# produces the expected result.
|
|
||||||
with ops.Graph().as_default():
|
|
||||||
_ = importer.import_graph_def(constant_graph_def, name="")
|
|
||||||
self.assertEqual(4, len(constant_graph_def.node))
|
|
||||||
self._ensure_no_variables_in_graph(constant_graph_def)
|
|
||||||
with session.Session() as sess:
|
|
||||||
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
|
||||||
output = self.evaluate(output_node)
|
|
||||||
self.assertNear(2.0, output, 0.00001)
|
|
||||||
|
|
||||||
def test_resource_variable_can_be_written_after_blacklisting(self):
|
|
||||||
with ops.Graph().as_default():
|
|
||||||
with variable_scope.variable_scope("", use_resource=True):
|
|
||||||
variable_node = variable_scope.get_variable(
|
|
||||||
"variable_node", initializer=1.0)
|
|
||||||
another_variable = variable_scope.get_variable(
|
|
||||||
"unused_variable_node", initializer=2.0)
|
|
||||||
with ops.control_dependencies([
|
|
||||||
variable_node.assign(another_variable + variable_node)]):
|
|
||||||
output_node = array_ops.identity(variable_node, name="output_node")
|
|
||||||
initializer_name = variable_node.initializer.name
|
|
||||||
with session.Session() as sess:
|
|
||||||
self.evaluate(variable_node.initializer)
|
|
||||||
self.evaluate(another_variable.initializer)
|
|
||||||
output = self.evaluate(output_node)
|
|
||||||
self.assertNear(3.0, output, 0.00001)
|
|
||||||
variable_graph_def = sess.graph.as_graph_def()
|
|
||||||
|
|
||||||
# Test variable name black list. This should result in the variable
|
|
||||||
# not being a const. Furthermore, the paths that read from and assign
|
|
||||||
# to the blacklisted variable should continue to be valid.
|
|
||||||
constant_graph_def_with_blacklist = (
|
|
||||||
graph_util.convert_variables_to_constants(
|
|
||||||
sess,
|
|
||||||
variable_graph_def, ["output_node", initializer_name],
|
|
||||||
variable_names_blacklist=set(["variable_node"])))
|
|
||||||
|
|
||||||
variable_node = None
|
|
||||||
for node in constant_graph_def_with_blacklist.node:
|
|
||||||
if node.name == "variable_node":
|
|
||||||
variable_node = node
|
|
||||||
self.assertIsNotNone(variable_node)
|
|
||||||
self.assertEqual(variable_node.op, "VarHandleOp")
|
|
||||||
|
|
||||||
# Now we make sure another_variable is now a constant, but the original
|
|
||||||
# variable is not, and that the graph can be executed and update the
|
|
||||||
# variable can be updated with each execution.
|
|
||||||
with ops.Graph().as_default():
|
|
||||||
_ = importer.import_graph_def(constant_graph_def_with_blacklist, name="")
|
|
||||||
with session.Session() as sess:
|
|
||||||
output_node = sess.graph.get_tensor_by_name("output_node:0")
|
|
||||||
self.evaluate(sess.graph.get_operation_by_name(initializer_name))
|
|
||||||
output = self.evaluate(output_node)
|
|
||||||
self.assertNear(3.0, output, 0.00001)
|
|
||||||
output = self.evaluate(output_node)
|
|
||||||
self.assertNear(5.0, output, 0.00001)
|
|
||||||
|
|
||||||
def _inline_functions(self, graph_def, arrays):
|
|
||||||
meta_graph = export_meta_graph(graph_def=graph_def)
|
|
||||||
fetch_collection = meta_graph_pb2.CollectionDef()
|
|
||||||
for name in arrays:
|
|
||||||
fetch_collection.node_list.value.append(name)
|
|
||||||
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
|
|
||||||
|
|
||||||
# Initialize RewriterConfig with everything disabled except function
|
|
||||||
# inlining.
|
|
||||||
config = config_pb2.ConfigProto()
|
|
||||||
rewrite_options = config.graph_options.rewrite_options
|
|
||||||
rewrite_options.optimizers.append("function")
|
|
||||||
return tf_optimizer.OptimizeGraph(config, meta_graph)
|
|
||||||
|
|
||||||
def _test_convert_variables_with_functions(self, inline_functions):
|
|
||||||
"""Freezes a graph with functions."""
|
|
||||||
|
|
||||||
@function.Defun(dtypes.float32)
|
|
||||||
def plus_one(x):
|
|
||||||
return x + 1.0
|
|
||||||
|
|
||||||
with ops.Graph().as_default():
|
|
||||||
variable_node = variables.Variable(1.0, name="variable_node")
|
|
||||||
_ = variables.Variable(1.0, name="unused_variable_node")
|
|
||||||
defun_node = plus_one(variable_node)
|
|
||||||
_ = math_ops_lib.multiply(defun_node, 2.0, name="output_node")
|
|
||||||
|
|
||||||
with session.Session() as sess:
|
|
||||||
self.evaluate(variables.variables_initializer([variable_node]))
|
|
||||||
variable_graph_def = sess.graph.as_graph_def()
|
|
||||||
|
|
||||||
if inline_functions:
|
|
||||||
# Run Grappler to create the VarOpHandle --> Placeholder -->
|
|
||||||
# ResourceVariable pattern.
|
|
||||||
variable_graph_def = self._inline_functions(
|
|
||||||
variable_graph_def, ["variable_node", "output_node"])
|
|
||||||
|
|
||||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
|
||||||
sess, variable_graph_def, ["output_node"])
|
|
||||||
|
|
||||||
self._ensure_no_variables_in_graph(constant_graph_def)
|
|
||||||
|
|
||||||
def testReferenceVariables(self):
|
|
||||||
"""Freezes a graph with reference variables."""
|
|
||||||
self._test_variable_to_const_conversion(use_resource=False)
|
|
||||||
|
|
||||||
def testResourceVariables(self):
|
|
||||||
"""Freezes a graph with resource variables."""
|
|
||||||
self._test_variable_to_const_conversion(use_resource=True)
|
|
||||||
|
|
||||||
def testWithFunctions(self):
|
|
||||||
"""Freezes a graph with functions."""
|
|
||||||
self._test_convert_variables_with_functions(inline_functions=False)
|
|
||||||
|
|
||||||
def testWithInlinedFunctions(self):
|
|
||||||
"""Freezes a graph with functions that have been inlined using Grappler."""
|
|
||||||
self._test_convert_variables_with_functions(inline_functions=True)
|
|
||||||
|
|
||||||
def testGraphWithSwitch(self):
|
|
||||||
"""Freezes a graph which contains a Switch with type RESOURCE_DT."""
|
|
||||||
with ops.Graph().as_default():
|
|
||||||
with variable_scope.variable_scope("", use_resource=True):
|
|
||||||
x = variable_scope.get_variable("var_x", initializer=1.0)
|
|
||||||
y = variable_scope.get_variable("var_y", initializer=2.0)
|
|
||||||
f1 = lambda: variable_scope.get_variable("var_f1", initializer=17.0)
|
|
||||||
f2 = lambda: variable_scope.get_variable("var_f2", initializer=23.0)
|
|
||||||
cond_node = control_flow_ops.case([(gen_math_ops.less(x, y), f1)],
|
|
||||||
default=f2)
|
|
||||||
_ = math_ops_lib.multiply(cond_node, 2.0, name="output_node")
|
|
||||||
|
|
||||||
with session.Session() as sess:
|
|
||||||
sess.run(variables.global_variables_initializer())
|
|
||||||
variable_graph_def = sess.graph.as_graph_def()
|
|
||||||
|
|
||||||
constant_graph_def = graph_util.convert_variables_to_constants(
|
|
||||||
sess, variable_graph_def, ["output_node"])
|
|
||||||
|
|
||||||
self._ensure_no_variables_in_graph(constant_graph_def)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user