Unify convert to constants logic.

Before this change there were two different code paths for dealing
with graph freezing for v1 and v2 graphs. They largely did the same
thing, but each path had and lacked capabilities the other had, and
each had its own bugs.

This change re-writes the previous v2 logic so it can cope with
session-based graphs and allow for the conversion of a subset of the
variables, and changes the previous convert_variables_to_constants
call to proxy into the new logic.

The new logic is built around a more "graphy" algorithm: variables are
converted to constants, and that conversion is then propagated through
the graph by following the graph edges. This hopefully makes it easier
to understand what is going on, and to change it later on.

More granular tests were added, in order to check that the right graph
manipulations were performed. In order to do that, some graph merging
infrastructure had to be created in the test.

PiperOrigin-RevId: 315497124
Change-Id: I3a33acc804b5dc9628c208df8fd1b7c59f906ddb
This commit is contained in:
Cesar Crusius 2020-06-09 09:19:09 -07:00 committed by TensorFlower Gardener
parent 75f09e6973
commit ed52277026
5 changed files with 1694 additions and 946 deletions

View File

@ -6633,6 +6633,7 @@ tf_py_test(
deps = [ deps = [
":client", ":client",
":client_testlib", ":client_testlib",
":control_flow_v2_toggles",
":framework", ":framework",
":framework_for_generated_wrappers", ":framework_for_generated_wrappers",
":math_ops", ":math_ops",
@ -6650,9 +6651,11 @@ tf_py_test(
python_version = "PY3", python_version = "PY3",
tags = ["no_rocm"], tags = ["no_rocm"],
deps = [ deps = [
"client_testlib", ":client_testlib",
"framework_test_lib", ":control_flow_v2_toggles",
":convert_to_constants", ":convert_to_constants",
":framework_test_lib",
":math_ops",
], ],
) )

File diff suppressed because it is too large Load Diff

View File

@ -19,33 +19,129 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import re
import numpy as np import numpy as np
from google.protobuf import text_format
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session as session_lib from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import convert_to_constants from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2 from tensorflow.python.ops import while_v2
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import simple_save from tensorflow.python.saved_model import simple_save
from tensorflow.python.saved_model.load import load from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save from tensorflow.python.saved_model.save import save
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest from tensorflow.python.util import nest
class _GraphMerger(object):
"""GraphDef merging methods for testing purposes."""
@staticmethod
def merge_any(x1, x2, empty_fn):
"""Merges two values using the message's CopyFrom/MergeFrom methods."""
merged = empty_fn()
merged.CopyFrom(x1)
merged.MergeFrom(x2)
return merged
@staticmethod
def merge_nodes(node1, node2):
"""Merges two NodeDef messages."""
merged = _GraphMerger.merge_any(node1, node2, node_def_pb2.NodeDef)
merged_inputs = node1.input[:]
merged_inputs.extend([i for i in node2.input[:] if i not in merged_inputs])
merged.input[:] = merged_inputs
return merged
@staticmethod
def merge_lists(repeated1, repeated2, empty_fn, key_fn, merge_fn):
"""Merges two lists representing maps."""
merged = {}
xs1 = {key_fn(x): x for x in repeated1}
xs2 = {key_fn(x): x for x in repeated2}
for name in set().union(xs1.keys(), xs2.keys()):
x1 = empty_fn() if name not in xs1 else xs1[name]
x2 = empty_fn() if name not in xs2 else xs2[name]
merged[name] = merge_fn(x1, x2)
return sorted(merged.values(), key=key_fn)
@staticmethod
def merge_node_lists(repeated_nodes1, repeated_nodes2):
"""Merges two repeated node fields."""
return _GraphMerger.merge_lists(repeated_nodes1, repeated_nodes2,
node_def_pb2.NodeDef, lambda n: n.name,
_GraphMerger.merge_nodes)
@staticmethod
def merge_functions(fn1, fn2):
"""Merges two FunctionDefs."""
merged = _GraphMerger.merge_any(fn1, fn2, function_pb2.FunctionDef)
del merged.signature.input_arg[:]
merged.signature.input_arg.extend(
_GraphMerger.merge_lists(
fn1.signature.input_arg[:], fn2.signature.input_arg[:],
op_def_pb2.OpDef.ArgDef, lambda a: a.name,
lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))
del merged.signature.output_arg[:]
merged.signature.output_arg.extend(
_GraphMerger.merge_lists(
fn1.signature.output_arg[:], fn2.signature.output_arg[:],
op_def_pb2.OpDef.ArgDef, lambda a: a.name,
lambda x, y: _GraphMerger.merge_any(x, y, op_def_pb2.OpDef.ArgDef)))
del merged.node_def[:]
merged.node_def.extend(
_GraphMerger.merge_node_lists(fn1.node_def[:], fn2.node_def[:]))
return merged
@staticmethod
def merge_graphs(graph1, graph2):
"""Merges two GraphDef messages."""
merged = graph_pb2.GraphDef()
merged.node.extend(
_GraphMerger.merge_node_lists(graph1.node[:], graph2.node[:]))
merged.library.function.extend(
_GraphMerger.merge_lists(graph1.library.function,
graph2.library.function,
function_pb2.FunctionDef,
lambda f: f.signature.name,
_GraphMerger.merge_functions))
return merged
class VariablesToConstantsTest(test.TestCase): class VariablesToConstantsTest(test.TestCase):
def _freezeModel(self, model): def _freezeModel(self, model):
@ -325,6 +421,7 @@ class VariablesToConstantsTest(test.TestCase):
cell, seq, dtype=dtypes.float32, sequence_length=[1]) cell, seq, dtype=dtypes.float32, sequence_length=[1])
root, output_func = self._freezeModel(model) root, output_func = self._freezeModel(model)
self._testConvertedFunction(root, root.f, output_func, input_data) self._testConvertedFunction(root, root.f, output_func, input_data)
@test_util.run_v2_only @test_util.run_v2_only
@ -347,6 +444,7 @@ class VariablesToConstantsTest(test.TestCase):
return control_flow_ops.while_loop(condition, body, [x]) return control_flow_ops.while_loop(condition, body, [x])
root, output_func = self._freezeModel(model) root, output_func = self._freezeModel(model)
self._testConvertedFunction(root, root.f, output_func, input_data) self._testConvertedFunction(root, root.f, output_func, input_data)
@test_util.run_v2_only @test_util.run_v2_only
@ -389,5 +487,665 @@ class VariablesToConstantsTest(test.TestCase):
self._testConvertedFunction(root, root.f, output_func, input_data) self._testConvertedFunction(root, root.f, output_func, input_data)
class ConvertVariablesToConstantsSessionTest(test.TestCase):
def _assertGraphContains(self, graph, subgraph):
"""Asserts that the given subgraph is contained within the given graph."""
def normalize_uids(msg):
"""Replace auto-id function names with something consistent."""
# These functions have non-deterministic names, the non-determinism coming
# from having an ops.uid() suffix in their names. We're replacing these
# with new sequential IDs starting from 0 for each prefix, which is
# is sufficient for tests.
if isinstance(msg, graph_pb2.GraphDef):
msg = text_format.MessageToString(msg)
name_prefixes = ["case_cond_true.*", "case_cond_false.*"]
name_regex = r"\b(" + "|".join(name_prefixes) + r")_([0-9]+)\b"
names = {}
for (name, index) in re.findall(name_regex, msg):
names.setdefault(name, set()).add(int(index))
for name, indices in names.items():
for new_index, old_index in enumerate(sorted(list(indices))):
msg = re.sub(r"\b" + name + "_" + str(old_index) + r"\b",
name + "_" + str(new_index), msg)
return msg
norm_graph = text_format.Parse(normalize_uids(graph), graph_pb2.GraphDef())
norm_subgraph = text_format.Parse(
normalize_uids(subgraph), graph_pb2.GraphDef())
# Graph S is contained in C if and only if merge(C,S) == C.
# We merge the input graph with an empty graph to normalize repeated fields:
# assertProtoEquals is sensitive to ordering.
norm_graph = _GraphMerger.merge_graphs(norm_graph, graph_pb2.GraphDef())
merged_graph = _GraphMerger.merge_graphs(norm_graph, norm_subgraph)
self.assertProtoEquals(norm_graph, merged_graph)
def _ensure_no_variables_in_graph(self, graph_def):
"""Ensures there are no variables in the graph."""
for node in graph_def.node:
self.assertNotIn(
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
def _test_variable_to_const_conversion(self, use_resource):
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=use_resource):
variable_node = variable_scope.get_variable(
"variable_node", initializer=1.0)
variable_scope.get_variable("unused_variable_node", initializer=1.0)
output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
with session_lib.Session() as sess:
self.evaluate(variable_node.initializer)
output = self.evaluate(output_node)
self.assertNear(2.0, output, 0.00001)
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
session=sess,
graph_def=variable_graph_def,
output_node_names=["output_node"]))
self._ensure_no_variables_in_graph(constant_graph_def)
# Now we make sure the variable is now a constant, and that the graph still
# produces the expected result.
with ops.Graph().as_default():
_ = importer.import_graph_def(constant_graph_def, name="")
self.assertEqual(4, len(constant_graph_def.node))
self._ensure_no_variables_in_graph(constant_graph_def)
with session_lib.Session() as sess:
output_node = sess.graph.get_tensor_by_name("output_node:0")
output = self.evaluate(output_node)
self.assertNear(2.0, output, 0.00001)
def test_resource_variable_can_be_written_after_blacklisting(self):
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
variable_node = variable_scope.get_variable(
"variable_node", initializer=1.0)
another_variable = variable_scope.get_variable(
"unused_variable_node", initializer=2.0)
with ops.control_dependencies(
[variable_node.assign(another_variable + variable_node)]):
output_node = array_ops.identity(variable_node, name="output_node")
initializer_name = variable_node.initializer.name
with session_lib.Session() as sess:
self.evaluate(variable_node.initializer)
self.evaluate(another_variable.initializer)
output = self.evaluate(output_node)
self.assertNear(3.0, output, 0.00001)
variable_graph_def = sess.graph.as_graph_def()
# Test variable name black list. This should result in the variable
# not being a const. Furthermore, the paths that read from and assign
# to the blacklisted variable should continue to be valid.
constant_graph_def_with_blacklist = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
session=sess,
graph_def=variable_graph_def,
output_node_names=["output_node", initializer_name],
variable_names_blacklist=set(["variable_node"])))
variable_node = None
for node in constant_graph_def_with_blacklist.node:
if node.name == "variable_node":
variable_node = node
self.assertIsNotNone(variable_node)
self.assertEqual(variable_node.op, "VarHandleOp")
# Now we make sure another_variable is now a constant, but the original
# variable is not, and that the graph can be executed and update the
# variable can be updated with each execution.
with ops.Graph().as_default():
_ = importer.import_graph_def(constant_graph_def_with_blacklist, name="")
with session_lib.Session() as sess:
output_node = sess.graph.get_tensor_by_name("output_node:0")
self.evaluate(sess.graph.get_operation_by_name(initializer_name))
output = self.evaluate(output_node)
self.assertNear(3.0, output, 0.00001)
output = self.evaluate(output_node)
self.assertNear(5.0, output, 0.00001)
def _inline_functions(self, graph_def, arrays):
meta_graph = export_meta_graph(graph_def=graph_def)
fetch_collection = meta_graph_pb2.CollectionDef()
for name in arrays:
fetch_collection.node_list.value.append(name)
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
# Initialize RewriterConfig with everything disabled except function
# inlining.
config = config_pb2.ConfigProto()
rewrite_options = config.graph_options.rewrite_options
rewrite_options.optimizers.append("function")
return tf_optimizer.OptimizeGraph(config, meta_graph)
def _test_convert_variables_with_functions(self, inline_functions):
"""Freezes a graph with functions."""
@function.Defun(dtypes.float32)
def plus_one(x):
return x + 1.0
with ops.Graph().as_default():
variable_node = variables.Variable(1.0, name="variable_node")
_ = variables.Variable(1.0, name="unused_variable_node")
defun_node = plus_one(variable_node)
_ = math_ops.multiply(defun_node, 2.0, name="output_node")
with session_lib.Session() as sess:
self.evaluate(variables.variables_initializer([variable_node]))
variable_graph_def = sess.graph.as_graph_def()
if inline_functions:
# Run Grappler to create the VarOpHandle --> Placeholder -->
# ResourceVariable pattern.
variable_graph_def = self._inline_functions(
variable_graph_def, ["variable_node", "output_node"])
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
session=sess,
graph_def=variable_graph_def,
output_node_names=["output_node"]))
self._ensure_no_variables_in_graph(constant_graph_def)
def testReferenceVariables(self):
"""Freezes a graph with reference variables."""
self._test_variable_to_const_conversion(use_resource=False)
def testResourceVariables(self):
"""Freezes a graph with resource variables."""
self._test_variable_to_const_conversion(use_resource=True)
def testWithFunctions(self):
"""Freezes a graph with functions."""
self._test_convert_variables_with_functions(inline_functions=False)
def testWithInlinedFunctions(self):
"""Freezes a graph with functions that have been inlined using Grappler."""
self._test_convert_variables_with_functions(inline_functions=True)
def testGraphWithSwitch(self):
"""Freezes a graph which contains a Switch with type RESOURCE_DT."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
x = variable_scope.get_variable("var_x", initializer=1.0)
y = variable_scope.get_variable("var_y", initializer=2.0)
f1 = lambda: variable_scope.get_variable("var_f1", initializer=17.0)
f2 = lambda: variable_scope.get_variable("var_f2", initializer=23.0)
cond_node = control_flow_ops.case([(gen_math_ops.less(x, y), f1)],
default=f2)
_ = math_ops.multiply(cond_node, 2.0, name="output_node")
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
session=sess,
graph_def=variable_graph_def,
output_node_names=["output_node"]))
self._ensure_no_variables_in_graph(constant_graph_def)
def testConvertSingleVariable(self):
"""Tests that a single variable is properly converted to a constant."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=False):
_ = variable_scope.get_variable("x", initializer=1.0)
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
sess, variable_graph_def, ["x/read"]))
self._assertGraphContains(
constant_graph_def, """
node {
name: "x" op: "Const"
attr { key: "dtype" value { type: DT_FLOAT } }
attr {
key: "value"
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
}
node {
name: "x/read" op: "Identity" input: "x"
attr { key: "T" value { type: DT_FLOAT } }
}""")
def testConvertSingleResourceVariable(self):
"""Tests that a resource variable is properly converted to a constant."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
_ = variable_scope.get_variable("x", initializer=1.0)
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
sess, variable_graph_def, ["x/Read/ReadVariableOp"]))
self._assertGraphContains(
constant_graph_def, """
node {
name: "x" op: "Const"
attr { key: "dtype" value { type: DT_FLOAT } }
attr {
key: "value"
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
}
node {
name: "x/Read/ReadVariableOp" op: "Identity" input: "x"
attr { key: "T" value { type: DT_FLOAT } }
}""")
def testConvertOneVariableOfTwo(self):
"""Tests that one variable can be kept unconverted."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=False):
x = variable_scope.get_variable("x", initializer=1.0)
y = variable_scope.get_variable("y", initializer=1.0)
_ = math_ops.multiply(x, y, name="out")
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
sess,
variable_graph_def, ["out"],
variable_names_blacklist=["y"]))
self._assertGraphContains(
constant_graph_def, """
node {
name: "x" op: "Const"
attr { key: "dtype" value { type: DT_FLOAT } }
attr {
key: "value"
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
}
node {
name: "x/read" op: "Identity" input: "x"
attr { key: "T" value { type: DT_FLOAT } }
}
node {
name: "y" op: "VariableV2"
attr { key: "dtype" value { type: DT_FLOAT } }
}
node {
name: "y/read" op: "Identity" input: "y"
attr { key: "T" value { type: DT_FLOAT } }
}
node {
name: "out" op: "Mul" input: "x/read" input: "y/read"
attr {key: "T" value {type: DT_FLOAT}}
}""")
def testConvertOneResourceVariableOfTwo(self):
"""Tests that one variable can be kept unconverted."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
x = variable_scope.get_variable("x", initializer=1.0)
y = variable_scope.get_variable("y", initializer=1.0)
_ = math_ops.multiply(x, y, name="out")
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
sess,
variable_graph_def, ["out"],
variable_names_blacklist=["y"]))
self._assertGraphContains(
constant_graph_def, """
node {
name: "x" op: "Const"
attr { key: "dtype" value { type: DT_FLOAT } }
attr {
key: "value"
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
}
node {
name: "y" op: "VarHandleOp"
attr { key: "dtype" value { type: DT_FLOAT } }
}
node {
name: "out/ReadVariableOp" op: "Identity" input: "x"
attr { key: "T" value { type: DT_FLOAT } }
}
node {
name: "out/ReadVariableOp_1" op: "ReadVariableOp" input: "y"
attr { key: "dtype" value { type: DT_FLOAT } }
}
node {
name: "out" op: "Mul"
input: "out/ReadVariableOp" input: "out/ReadVariableOp_1"
attr {key: "T" value {type: DT_FLOAT}}
}""")
def testConvertIdentityChain(self):
"""Tests that a chain of Identity ops is converted properly."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
x = variable_scope.get_variable("x", initializer=1.0)
y = array_ops.identity(x, name="y")
_ = array_ops.identity(y, name="z")
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
sess, variable_graph_def, ["z"]))
self._assertGraphContains(
constant_graph_def, """
node {
name: "x" op: "Const"
attr { key: "dtype" value { type: DT_FLOAT } }
attr {
key: "value"
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
}
node {
name: "y/ReadVariableOp" op: "Identity" input: "x"
attr { key: "T" value { type: DT_FLOAT } }
}
node {
name: "y" op: "Identity" input: "y/ReadVariableOp"
attr { key: "T" value { type: DT_FLOAT } }
}
node {
name: "z" op: "Identity" input: "y"
attr { key: "T" value { type: DT_FLOAT } }
}""")
def testConvertCase(self):
"""Tests that a v1 case() construction converts properly."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=False):
control_flow_v2_toggles.disable_control_flow_v2()
x = variable_scope.get_variable("x", initializer=1.0)
y = variable_scope.get_variable("y", initializer=2.0)
_ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)],
default=lambda: y)
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
sess, variable_graph_def, ["case/cond/Merge"]))
self._assertGraphContains(
constant_graph_def, """
node {
name: "x" op: "Const"
attr { key: "dtype" value { type: DT_FLOAT } }
attr {
key: "value"
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
}
node {
name: "y" op: "Const"
attr { key: "dtype" value { type: DT_FLOAT } }
attr {
key: "value"
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}}
}
node {name: "x/read" op: "Identity" input: "x"}
node {name: "y/read" op: "Identity" input: "y"}
node {name: "Less" op: "Less" input: "x/read" input: "y/read"}
node {name: "case/cond/pred_id" op: "Identity" input: "Less"}
node {
name: "case/cond/Switch_1" op: "Switch"
input: "case/cond/pred_id" input: "x/read"
}
node {
name: "case/cond/Switch_2" op: "Switch"
input: "case/cond/pred_id" input: "y/read"
}
node {
name: "case/cond/Merge" op: "Merge"
input: "case/cond/Switch_2" input: "case/cond/Switch_1:1"
attr {key: "T" value {type: DT_FLOAT}}
}""")
def testConvertV2Case(self):
"""Tests that a v2 case() converts properly."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=False):
control_flow_v2_toggles.enable_control_flow_v2()
a = variable_scope.get_variable("a", initializer=2.0)
x = variable_scope.get_variable("x", initializer=1.0)
y = variable_scope.get_variable("y", initializer=2.0)
_ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: a)],
default=lambda: y)
control_flow_v2_toggles.disable_control_flow_v2()
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
sess, variable_graph_def, ["case/cond"]))
self._assertGraphContains(
constant_graph_def, """
node {
name: "x" op: "Const"
attr { key: "dtype" value { type: DT_FLOAT } }
attr {
key: "value"
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 1 }}}
}
node {
name: "y" op: "Const"
attr { key: "dtype" value { type: DT_FLOAT } }
attr {
key: "value"
value { tensor { dtype: DT_FLOAT tensor_shape{} float_val: 2 }}}
}
node {name: "x/read" op: "Identity" input: "x"}
node {name: "y/read" op: "Identity" input: "y"}
node {name: "Less" op: "Less" input: "x/read" input: "y/read"}
node {
name: "case/cond" op: "StatelessIf"
input: "Less" input: "a/read" input: "y/read"
attr {key: "Tcond" value {type: DT_BOOL}}
attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}}
attr {key: "Tout" value {list {type: DT_FLOAT}}}
}
library {
function {
signature {
name: "case_cond_false_frozen_0"
input_arg {name: "placeholder" type: DT_FLOAT}
input_arg {name: "y_read_0" type: DT_FLOAT}
output_arg {name: "y_read" type: DT_FLOAT}
}
}
function {
signature {
name: "case_cond_true_frozen_0"
input_arg {name: "a_read_0" type: DT_FLOAT}
input_arg {name: "placeholder" type: DT_FLOAT}
output_arg {name: "a_read" type: DT_FLOAT}
}
}
}""")
def testConvertV2ResourceCase(self):
"""Tests that a v2 case() with resource variables converts properly."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
control_flow_v2_toggles.enable_control_flow_v2()
x = variable_scope.get_variable("x", initializer=1.0)
y = variable_scope.get_variable("y", initializer=2.0)
_ = control_flow_ops.case([(gen_math_ops.less(x, y), lambda: x)],
default=lambda: y)
control_flow_v2_toggles.disable_control_flow_v2()
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
sess, variable_graph_def, ["case/cond"]))
self._assertGraphContains(
constant_graph_def, """
node {name: "x" op: "Const"}
node {name: "y" op: "Const"}
node {
name: "case/cond" op: "If" input: "Less" input: "x" input: "y"
attr {key: "Tcond" value {type: DT_BOOL}}
attr {key: "Tin" value {list {type: DT_FLOAT type: DT_FLOAT}}}
attr {key: "Tout" value {list {type: DT_FLOAT}}}
}
library {
function {
signature {
name: "case_cond_false_frozen_0"
input_arg {name: "placeholder" type: DT_FLOAT}
input_arg {name: "readvariableop_y" type: DT_FLOAT}
output_arg {name: "readvariableop" type: DT_FLOAT}
}
}
function {
signature {
name: "case_cond_true_frozen_0"
input_arg {name: "placeholder" type: DT_FLOAT}
input_arg {name: "readvariableop_x" type: DT_FLOAT}
output_arg {name: "readvariableop" type: DT_FLOAT}
}
}
}""")
def testConvertV2UnconvertedResourceNestedCase(self):
"""Tests unconverted variable propagation through nested functions."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
control_flow_v2_toggles.enable_control_flow_v2()
x = variable_scope.get_variable("x", initializer=1.0)
y = variable_scope.get_variable("y", initializer=2.0)
z = variable_scope.get_variable("z", initializer=3.0)
# pylint: disable=g-long-lambda
_ = control_flow_ops.case(
[(gen_math_ops.less(x, y), lambda: x)],
default=lambda: control_flow_ops.case(
[(gen_math_ops.less(z, y), lambda: z)], default=lambda: y))
# pylint: enable=g-long-lambda
control_flow_v2_toggles.disable_control_flow_v2()
with session_lib.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = (
convert_to_constants
.convert_variables_to_constants_from_session_graph(
sess,
variable_graph_def, ["case/cond"],
variable_names_blacklist=["y"]))
self._assertGraphContains(
constant_graph_def, """
node {name: "x" op: "Const"}
node {name: "y" op: "VarHandleOp"}
node {name: "z" op: "Const"}
node {name: "Less/ReadVariableOp" op: "Identity" input: "x"}
node {name: "Less/ReadVariableOp_1" op: "ReadVariableOp" input: "y"}
node {
name: "case/cond" op: "If"
input: "x" input: "z" input: "y"
attr {
key: "Tin"
value {list
{type: DT_FLOAT type: DT_FLOAT type: DT_RESOURCE}}}
attr {
key: "_read_only_resource_inputs"
value {list {i: 1 i: 2 i: 3}}}
attr {key: "then_branch"
value {func {name: "case_cond_true_frozen_0"}}}
attr {key: "else_branch"
value {func {name: "case_cond_false_frozen_0"}}}
attr {key: "output_shapes" value {list {shape {}}}}
}
library {
function {
signature {
name: "case_cond_true_frozen_0"
input_arg {name: "placeholder" type: DT_FLOAT}
input_arg {name: "placeholder_1" type: DT_RESOURCE}
input_arg {name: "readvariableop_x" type: DT_FLOAT}
output_arg {name: "readvariableop" type: DT_FLOAT}
is_stateful: true
}
node_def {name: "ReadVariableOp" op: "Identity"
input: "readvariableop_x"}}
function {
signature {
name: "case_cond_false_frozen_0"
input_arg {name: "placeholder" type: DT_FLOAT}
input_arg {name: "less_readvariableop_1_y" type: DT_RESOURCE}
input_arg {name: "less_readvariableop_z" type: DT_FLOAT}
output_arg {name: "case_cond_identity" type: DT_FLOAT}
is_stateful: true
}
node_def {name: "Less/ReadVariableOp_1" op: "ReadVariableOp"
input: "less_readvariableop_1_y"}
node_def {name: "Less/ReadVariableOp" op: "Identity"
input: "less_readvariableop_z"}
node_def {name: "case/cond" op: "If"
input: "less_readvariableop_z"
input: "less_readvariableop_1_y"
attr {
key: "Tin"
value {list {type: DT_FLOAT type: DT_RESOURCE}}}
attr {key: "then_branch"
value {func {name: "case_cond_true_frozen_1"}}}
attr {key: "else_branch"
value {func {name: "case_cond_false_frozen_1"}}}
attr {
key: "_read_only_resource_inputs"
value {list {i: 1 i: 2}}}}}
function {
signature {
name: "case_cond_false_frozen_1"
input_arg {name: "placeholder" type: DT_FLOAT}
input_arg {name: "readvariableop_y" type: DT_RESOURCE}
output_arg {name: "readvariableop" type: DT_FLOAT}
is_stateful: true
}
node_def {name: "ReadVariableOp" op: "ReadVariableOp"
input: "readvariableop_y"}}
function {
signature {
name: "case_cond_true_frozen_1"
input_arg {name: "placeholder" type: DT_RESOURCE}
input_arg {name: "readvariableop_z" type: DT_FLOAT}
output_arg {name: "readvariableop" type: DT_FLOAT}
is_stateful: true
}
node_def {name: "ReadVariableOp" op: "Identity"
input: "readvariableop_z"}}}""")
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -23,16 +23,19 @@ import re
import six import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# A normal import here would generate circular dependencies.
convert_to_constants = lazy_loader.LazyLoader(
"convert_to_constants", globals(),
"tensorflow.python.framework.convert_to_constants")
_VARIABLE_OPS = { _VARIABLE_OPS = {
"Assign", "Assign",
"AssignAdd", "AssignAdd",
@ -237,76 +240,6 @@ def tensor_shape_from_node_def_name(graph, input_name):
return shape return shape
def _update_resource_identities(resource_identities, output_graph_def,
variable_names_whitelist,
variable_names_blacklist):
"""Updates the type of DT_RESOURCE Identity ops.
Updates the type of the `resource_identities` to the type of the node that
feed into it if the node is not an input to any other node. Valid nodes are
generally colocated nodes.
Args:
resource_identities: List of NodeDef protos that are Identity ops with the
type DT_RESOURCE.
output_graph_def: GraphDef proto.
variable_names_whitelist: The set of variable names to convert (by default,
all variables are converted).
variable_names_blacklist: The set of variable names to omit converting
to constants.
"""
# Identify the nodes in the graph and the nodes consuming each node.
map_name_to_node = {}
map_name_to_inputs = {}
for node in output_graph_def.node:
map_name_to_node[node.name] = node
for unparsed_input_name in node.input:
if not unparsed_input_name.startswith("^"):
parsed_input_name = _node_name(unparsed_input_name)
if parsed_input_name not in map_name_to_inputs:
map_name_to_inputs[parsed_input_name] = []
map_name_to_inputs[parsed_input_name].append(node.name)
for node in resource_identities:
# Validate the node is not an input to other nodes.
if node.name in map_name_to_inputs:
continue
# Get the type of the Identity node by tracing back through the nodes until
# we come to a non-Identity or non-control flow node or the type of the node
# is not DT_RESOURCE.
input_node = map_name_to_node[_node_name(node.input[0])]
while (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY and
input_node.attr["T"].type == dtypes.resource):
input_node = map_name_to_node[_node_name(input_node.input[0])]
# Update the type of the Identity node if an Identity, control flow, or
# VarHandleOp node with a type that is not DT_RESOURCE is found.
debugging_message = str.encode(
"This Identity's type was changed from DT_RESOURCE during graph "
"freezing.")
if input_node.attr["T"].type != dtypes.resource:
if (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY
and _should_convert(
input_node.input[0],
variable_names_whitelist,
variable_names_blacklist)):
node.attr["T"].CopyFrom(input_node.attr["T"])
node.attr["_debugging"].s = debugging_message
elif (input_node.op == "VarHandleOp"
and _should_convert(
input_node.name,
variable_names_whitelist,
variable_names_blacklist)):
node.attr["T"].CopyFrom(input_node.attr["dtype"])
node.attr["_debugging"].s = debugging_message
def _should_convert(name, whitelist, blacklist):
return ((whitelist is None or name in whitelist)
and (blacklist is None or name not in blacklist))
@deprecation.deprecated( @deprecation.deprecated(
date=None, date=None,
instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`") instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`")
@ -339,190 +272,16 @@ def convert_variables_to_constants(sess,
RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both
blacklisted AND whitelisted for freezing. blacklisted AND whitelisted for freezing.
""" """
ret = convert_to_constants.convert_variables_to_constants_from_session_graph(
get_input_name = lambda node, index=0: node.input[index].split(":")[0] session=sess,
graph_def=input_graph_def,
def create_const_op(node_name, dtype, data, data_shape=None): output_node_names=output_node_names,
"""Creates a Const op.""" variable_names_whitelist=variable_names_whitelist,
output_node = node_def_pb2.NodeDef() variable_names_blacklist=variable_names_blacklist)
output_node.op = "Const" # The previous code logic generated an empty versions field, we clear it here
output_node.name = node_name # to maintain backwards compatibility.
output_node.attr["dtype"].CopyFrom(dtype) ret.versions.Clear()
output_node.attr["value"].CopyFrom( return ret
attr_value_pb2.AttrValue(
tensor=tensor_util.make_tensor_proto(
data, dtype=dtype.type, shape=data_shape)))
return output_node
# This graph only includes the nodes needed to evaluate the output nodes, and
# removes unneeded nodes like those involved in saving and assignment.
inference_graph = extract_sub_graph(input_graph_def, output_node_names)
# Identify the ops in the graph.
map_name_to_node = {
node.name: node for node in inference_graph.node
}
# Get list of variables.
variable_names = []
variable_dict_names = []
resource_op_types = {}
for node in inference_graph.node:
if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
variable_name = node.name
if not _should_convert(
variable_name, variable_names_whitelist, variable_names_blacklist):
continue
variable_dict_names.append(variable_name)
if node.op == "VarHandleOp":
variable_names.append(variable_name + "/Read/ReadVariableOp:0")
else:
variable_names.append(variable_name + ":0")
elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd"]:
# There can be one or more Identity or control flow ops in between the
# ReadVariableOp and VarHandleOp. Store the ops with the associated
# dtypes.
source_op_names = [get_input_name(node)]
candidate_resource_op_types = {}
while (source_op_names and map_name_to_node[source_op_names[0]].op in
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
source_op_name = source_op_names.pop()
current_node = map_name_to_node[source_op_name]
if (source_op_name not in resource_op_types and
source_op_name not in candidate_resource_op_types):
candidate_resource_op_types[source_op_name] = node.attr["dtype"]
source_op_names.append(get_input_name(current_node))
if current_node == "Merge":
merge_resource_name = get_input_name(current_node, index=1)
if (merge_resource_name not in resource_op_types
and merge_resource_name not in candidate_resource_op_types):
candidate_resource_op_types[merge_resource_name] = (
node.attr["dtype"])
source_op_names.append(
get_input_name(map_name_to_node[merge_resource_name]))
should_convert_all = None
for source_node in source_op_names:
if map_name_to_node[source_node].op != "VarHandleOp":
raise ValueError("Cannot find the variable that is an input "
"to the ReadVariableOp.")
should_convert_node = _should_convert(
source_node, variable_names_whitelist, variable_names_blacklist)
if should_convert_all is None:
should_convert_all = should_convert_node
elif should_convert_all != should_convert_node:
raise RuntimeError(
"Found DT_RESOURCE node whose ancestor Variables are both "
"blacklisted AND whitelisted for freezing. Originating "
"descendant node: {}. Ancestor variables: {}.".format(
node.name, source_op_names))
if should_convert_all in (None, True):
resource_op_types.update(candidate_resource_op_types)
# Gets map of variables and the associated data.
if variable_names:
returned_variables = sess.run(variable_names)
else:
returned_variables = []
variables_data_map = dict(zip(variable_dict_names, returned_variables))
logging.info("Froze %d variables.", len(returned_variables))
def _should_convert_ancestor(node):
input_node = map_name_to_node[_node_name(node.input[0])]
while (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY and
input_node.attr["T"].type == dtypes.resource):
input_node = map_name_to_node[_node_name(input_node.input[0])]
return _should_convert(input_node.name,
variable_names_whitelist,
variable_names_blacklist)
# Reconstruct the graph with constants in place of variables.
output_graph_def = graph_pb2.GraphDef()
how_many_converted = 0
for input_node in inference_graph.node:
output_node = node_def_pb2.NodeDef()
if input_node.name in variables_data_map:
data = variables_data_map[input_node.name]
output_node = create_const_op(input_node.name, input_node.attr["dtype"],
data, data.shape)
how_many_converted += 1
elif input_node.name in resource_op_types:
# Converts the type of the ops between the ReadVariableOp and VarHandleOp
# from RESOURCE_DT to the appropriate type based on the input they are
# referencing. Do not copy shapes due to incorrect shape info.
output_node.op = input_node.op
output_node.name = input_node.name
for in_node in input_node.input:
output_node.input.append(in_node)
for attr_name in input_node.attr:
if str(attr_name) != "_output_shapes":
output_node.attr[attr_name].CopyFrom(input_node.attr[attr_name])
output_node.attr["T"].CopyFrom(resource_op_types[input_node.name])
elif (input_node.op == "ReadVariableOp"
and _should_convert_ancestor(input_node)):
# The first branch converts all VarHandleOps of ResourceVariables to
# constants, so we need to convert the associated ReadVariableOps to
# Identity ops.
output_node.op = "Identity"
output_node.name = input_node.name
output_node.input.extend([input_node.input[0]])
output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
if "_class" in input_node.attr:
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
elif (input_node.op == "ResourceGather"
and _should_convert_ancestor(input_node)):
# The first branch converts all VarHandleOps of ResourceGather to
# constants, so we need to convert the associated ResourceGather to Gather
# ops with a Const axis feeding into it.
if input_node.attr["batch_dims"].i != 0:
raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
axis_data = input_node.attr["batch_dims"].i
axis_node_name = input_node.name + "/axis"
axis_dtype = input_node.attr["Tindices"]
output_axis_node = create_const_op(axis_node_name, axis_dtype, axis_data)
output_graph_def.node.extend([output_axis_node])
output_node.op = "GatherV2"
output_node.name = input_node.name
output_node.input.extend(
[input_node.input[0], input_node.input[1], axis_node_name])
output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
output_node.attr["Taxis"].CopyFrom(axis_dtype)
if "_class" in input_node.attr:
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
elif (input_node.op == "ResourceGatherNd"
and _should_convert_ancestor(input_node)):
output_node.op = "GatherNd"
output_node.name = input_node.name
output_node.input.extend(
[input_node.input[0], input_node.input[1]])
output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
if "_class" in input_node.attr:
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
else:
output_node.CopyFrom(input_node)
output_graph_def.node.append(output_node)
# Update the types of the DT_RESOURCE Identity nodes that do not have an
# associated ReadVariableOp.
resource_identities = []
for node in output_graph_def.node:
if node.op == "Identity" and node.attr["T"].type == dtypes.resource:
resource_identities.append(node)
if resource_identities:
_update_resource_identities(resource_identities,
output_graph_def,
variable_names_whitelist,
variable_names_blacklist)
output_graph_def.library.CopyFrom(inference_graph.library)
logging.info("Converted %d variables to const ops.", how_many_converted)
return output_graph_def
@deprecation.deprecated( @deprecation.deprecated(

View File

@ -21,27 +21,15 @@ from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_util from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops as math_ops_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training.saver import export_meta_graph
# Utility device function to use for testing # Utility device function to use for testing
@ -316,203 +304,5 @@ class DeviceFunctionsTest(test.TestCase):
graph_util.remove_training_nodes(graph_def)) graph_util.remove_training_nodes(graph_def))
class ConvertVariablesToConstantsTest(test.TestCase):
def _ensure_no_variables_in_graph(self, graph_def):
"""Ensures there are no variables in the graph."""
for node in graph_def.node:
self.assertNotIn(
node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
def _test_variable_to_const_conversion(self, use_resource):
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=use_resource):
variable_node = variable_scope.get_variable(
"variable_node", initializer=1.0)
another_variable = variable_scope.get_variable(
"unused_variable_node", initializer=1.0)
output_node = math_ops_lib.multiply(
variable_node, 2.0, name="output_node")
with session.Session() as sess:
self.evaluate(variable_node.initializer)
output = self.evaluate(output_node)
self.assertNear(2.0, output, 0.00001)
variable_graph_def = sess.graph.as_graph_def()
# First get the constant_graph_def when variable_names_whitelist is
# set, note that if variable_names_whitelist is not set an error will
# be thrown because unused_variable_node is not initialized.
constant_graph_def = graph_util.convert_variables_to_constants(
sess,
variable_graph_def, ["output_node"],
variable_names_whitelist=set(["variable_node"]))
# Then initialize the unused variable, and get another
# constant_graph_def when variable_names_whitelist is not set.
self.evaluate(another_variable.initializer)
constant_graph_def_without_variable_whitelist = (
graph_util.convert_variables_to_constants(
sess, variable_graph_def, ["output_node"]))
# The unused variable should be cleared so the two graphs should be
# equivalent.
self.assertEqual(
str(constant_graph_def),
str(constant_graph_def_without_variable_whitelist))
# Test variable name black list. This should result in the variable
# not being a const.
constant_graph_def_with_blacklist = (
graph_util.convert_variables_to_constants(
sess,
variable_graph_def, ["output_node"],
variable_names_blacklist=set(["variable_node"])))
variable_node = None
for node in constant_graph_def_with_blacklist.node:
if node.name == "variable_node":
variable_node = node
self.assertIsNotNone(variable_node)
if use_resource:
self.assertEqual(variable_node.op, "VarHandleOp")
else:
self.assertEqual(variable_node.op, "VariableV2")
# Now we make sure the variable is now a constant, and that the graph still
# produces the expected result.
with ops.Graph().as_default():
_ = importer.import_graph_def(constant_graph_def, name="")
self.assertEqual(4, len(constant_graph_def.node))
self._ensure_no_variables_in_graph(constant_graph_def)
with session.Session() as sess:
output_node = sess.graph.get_tensor_by_name("output_node:0")
output = self.evaluate(output_node)
self.assertNear(2.0, output, 0.00001)
def test_resource_variable_can_be_written_after_blacklisting(self):
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
variable_node = variable_scope.get_variable(
"variable_node", initializer=1.0)
another_variable = variable_scope.get_variable(
"unused_variable_node", initializer=2.0)
with ops.control_dependencies([
variable_node.assign(another_variable + variable_node)]):
output_node = array_ops.identity(variable_node, name="output_node")
initializer_name = variable_node.initializer.name
with session.Session() as sess:
self.evaluate(variable_node.initializer)
self.evaluate(another_variable.initializer)
output = self.evaluate(output_node)
self.assertNear(3.0, output, 0.00001)
variable_graph_def = sess.graph.as_graph_def()
# Test variable name black list. This should result in the variable
# not being a const. Furthermore, the paths that read from and assign
# to the blacklisted variable should continue to be valid.
constant_graph_def_with_blacklist = (
graph_util.convert_variables_to_constants(
sess,
variable_graph_def, ["output_node", initializer_name],
variable_names_blacklist=set(["variable_node"])))
variable_node = None
for node in constant_graph_def_with_blacklist.node:
if node.name == "variable_node":
variable_node = node
self.assertIsNotNone(variable_node)
self.assertEqual(variable_node.op, "VarHandleOp")
# Now we make sure another_variable is now a constant, but the original
# variable is not, and that the graph can be executed and update the
# variable can be updated with each execution.
with ops.Graph().as_default():
_ = importer.import_graph_def(constant_graph_def_with_blacklist, name="")
with session.Session() as sess:
output_node = sess.graph.get_tensor_by_name("output_node:0")
self.evaluate(sess.graph.get_operation_by_name(initializer_name))
output = self.evaluate(output_node)
self.assertNear(3.0, output, 0.00001)
output = self.evaluate(output_node)
self.assertNear(5.0, output, 0.00001)
def _inline_functions(self, graph_def, arrays):
meta_graph = export_meta_graph(graph_def=graph_def)
fetch_collection = meta_graph_pb2.CollectionDef()
for name in arrays:
fetch_collection.node_list.value.append(name)
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
# Initialize RewriterConfig with everything disabled except function
# inlining.
config = config_pb2.ConfigProto()
rewrite_options = config.graph_options.rewrite_options
rewrite_options.optimizers.append("function")
return tf_optimizer.OptimizeGraph(config, meta_graph)
def _test_convert_variables_with_functions(self, inline_functions):
"""Freezes a graph with functions."""
@function.Defun(dtypes.float32)
def plus_one(x):
return x + 1.0
with ops.Graph().as_default():
variable_node = variables.Variable(1.0, name="variable_node")
_ = variables.Variable(1.0, name="unused_variable_node")
defun_node = plus_one(variable_node)
_ = math_ops_lib.multiply(defun_node, 2.0, name="output_node")
with session.Session() as sess:
self.evaluate(variables.variables_initializer([variable_node]))
variable_graph_def = sess.graph.as_graph_def()
if inline_functions:
# Run Grappler to create the VarOpHandle --> Placeholder -->
# ResourceVariable pattern.
variable_graph_def = self._inline_functions(
variable_graph_def, ["variable_node", "output_node"])
constant_graph_def = graph_util.convert_variables_to_constants(
sess, variable_graph_def, ["output_node"])
self._ensure_no_variables_in_graph(constant_graph_def)
def testReferenceVariables(self):
"""Freezes a graph with reference variables."""
self._test_variable_to_const_conversion(use_resource=False)
def testResourceVariables(self):
"""Freezes a graph with resource variables."""
self._test_variable_to_const_conversion(use_resource=True)
def testWithFunctions(self):
"""Freezes a graph with functions."""
self._test_convert_variables_with_functions(inline_functions=False)
def testWithInlinedFunctions(self):
"""Freezes a graph with functions that have been inlined using Grappler."""
self._test_convert_variables_with_functions(inline_functions=True)
def testGraphWithSwitch(self):
"""Freezes a graph which contains a Switch with type RESOURCE_DT."""
with ops.Graph().as_default():
with variable_scope.variable_scope("", use_resource=True):
x = variable_scope.get_variable("var_x", initializer=1.0)
y = variable_scope.get_variable("var_y", initializer=2.0)
f1 = lambda: variable_scope.get_variable("var_f1", initializer=17.0)
f2 = lambda: variable_scope.get_variable("var_f2", initializer=23.0)
cond_node = control_flow_ops.case([(gen_math_ops.less(x, y), f1)],
default=f2)
_ = math_ops_lib.multiply(cond_node, 2.0, name="output_node")
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
variable_graph_def = sess.graph.as_graph_def()
constant_graph_def = graph_util.convert_variables_to_constants(
sess, variable_graph_def, ["output_node"])
self._ensure_no_variables_in_graph(constant_graph_def)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()