STT-tensorflow/tensorflow/python/framework/graph_util_test.py
Eugene Brevdo 96c84224c8 Some feature additions to saved_model_cli AOT compile & bugfixes in freeze graph
1. saved_model_cli now properly identifies read-only and non-readonly vars in a graph
   and only freezes the readonly variables (and marks the others as not readonly).
2. Fixed bugs in convert_variables_to_constants where the blacklist/whitelist was not
   properly respected for chains of operations.  While the VarHandleOp node was properly
   blacklisted, downstream ops that used the resource would be converted as if the
   variable had been frozen.  For example, the following graph would break:

   VarHandleOp -> Identity -> [first arg in] ResourceAssign

   The attrs of the Identity op would be changed to DT_FLOAT though it should stay as
   DT_RESOURCE.
3. Added support for freezing of *Nd ops (ResourceGatherNd, ResourceScatterNd).

PiperOrigin-RevId: 293239272
Change-Id: I06de3d139c5585a93ba585f076edf92137e4c48a
2020-02-04 15:16:22 -08:00

621 lines
25 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.client.graph_util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops # pylint: disable=unused-import
from tensorflow.python.ops import math_ops as math_ops_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.saver import export_meta_graph
# Utility device function to use for testing
def test_device_func_pin_variable_to_cpu(op):
if op.device:
return op.device
return "/cpu:0" if op.node_def.op in ["Variable", "VariableV2"] else op.device
class DeviceFunctionsTest(test.TestCase):
  """Tests for device scopes and for the GraphDef pruning utilities.

  The first group of tests checks which device ends up on ops created under
  `Graph.device` / `ops.device` scopes (device strings and device functions).
  The second group exercises `graph_util.extract_sub_graph` and
  `graph_util.remove_training_nodes` on hand-built `GraphDef` protos.
  """

  def _make_variable(self, name):
    """Creates a float32 variable op of shape [1] called `name`.

    The op is created in whatever graph/device scope is active at call time.

    Args:
      name: Name to give to the created op.

    Returns:
      The result of `gen_state_ops.variable`.
    """
    return gen_state_ops.variable(
        shape=[1],
        dtype=dtypes.float32,
        name=name,
        container="",
        shared_name="")

  def testTwoDeviceFunctions(self):
    """Checks device assignment across sibling and nested device scopes."""
    with ops.Graph().as_default() as g:
      var_0 = self._make_variable("var_0")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_1 = self._make_variable("var_1")
      # Created after the device-function scope has exited.
      var_2 = self._make_variable("var_2")
      var_3 = self._make_variable("var_3")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_4 = self._make_variable("var_4")
        with g.device("/device:GPU:0"):
          var_5 = self._make_variable("var_5")
        # Back under the device function only, outside the GPU scope.
        var_6 = self._make_variable("var_6")

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, None)
    self.assertDeviceEqual(var_3.device, None)
    self.assertDeviceEqual(var_4.device, "/device:CPU:0")
    self.assertDeviceEqual(var_5.device, "/device:GPU:0")
    self.assertDeviceEqual(var_6.device, "/device:CPU:0")

  @test_util.run_v1_only("b/120545219")
  def testNestedDeviceFunctions(self):
    """Checks device functions nested inside another device function."""
    with ops.Graph().as_default():
      var_0 = variables.VariableV1(0)
      with ops.device(test_device_func_pin_variable_to_cpu):
        var_1 = variables.VariableV1(1)
        with ops.device(lambda op: "/device:GPU:0"):
          var_2 = variables.VariableV1(2)
        with ops.device("/device:GPU:0"):  # Implicit merging device function.
          var_3 = variables.VariableV1(3)

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, "/device:GPU:0")
    self.assertDeviceEqual(var_3.device, "/device:GPU:0")

  def testExplicitDevice(self):
    """Checks that explicit device strings are recorded on constants."""
    with ops.Graph().as_default() as g:
      const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/job:ps"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, None)
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/job:ps")

  def testDefaultDevice(self):
    """Checks explicit device strings under an enclosing device function."""
    with ops.Graph().as_default() as g, g.device(
        test_device_func_pin_variable_to_cpu):
      with g.device("/job:ps"):
        const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/replica:0"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, "/job:ps")
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/replica:0")

  def testExtractSubGraph(self):
    """Checks that only nodes reachable from the destination are kept."""
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    n1.input.extend(["n5"])
    n2 = graph_def.node.add()
    n2.name = "n2"
    # Take the first output of the n1 node as the input.
    n2.input.extend(["n1:0"])
    n3 = graph_def.node.add()
    n3.name = "n3"
    # Add a control input (which isn't really needed by the kernel, but
    # rather to enforce execution order between nodes).
    n3.input.extend(["^n2"])
    # n4 has no connection to n3 and must not appear in the sub graph.
    n4 = graph_def.node.add()
    n4.name = "n4"
    # It is fine to have loops in the graph as well.
    n5 = graph_def.node.add()
    n5.name = "n5"
    n5.input.extend(["n1"])

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    # Comparing the complete name list (rather than indexing the first four
    # entries) also verifies that the unreachable node n4 was dropped.
    self.assertEqual(["n1", "n2", "n3", "n5"],
                     [node.name for node in sub_graph.node])

  def testExtractSubGraphWithInvalidDestNodes(self):
    """Checks that a bare string is rejected as the destination list."""
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    with self.assertRaisesRegexp(TypeError, "must be a list"):
      graph_util.extract_sub_graph(graph_def, "n1")

  def create_node_def(self, op, name, inputs):
    """Builds a `NodeDef` with the given op type, name and inputs.

    Args:
      op: Op type string stored in `NodeDef.op`.
      name: Node name stored in `NodeDef.name`.
      inputs: Iterable of input names appended to `NodeDef.input`.

    Returns:
      The newly created `NodeDef`.
    """
    new_node = node_def_pb2.NodeDef()
    new_node.op = op
    new_node.name = name
    new_node.input.extend(inputs)
    return new_node

  def create_constant_node_def(self,
                               name,
                               value,
                               dtype,
                               shape=None,
                               inputs=None):
    """Builds a "Const" `NodeDef` carrying `value`.

    Args:
      name: Node name.
      value: Value forwarded to `tensor_util.make_tensor_proto`.
      dtype: `DType` used for both the "dtype" and "value" attrs.
      shape: Optional shape forwarded to `make_tensor_proto`.
      inputs: Optional list of input names; treated as empty when falsy.

    Returns:
      The newly created `NodeDef`.
    """
    node = self.create_node_def("Const", name, inputs or [])
    self.set_attr_dtype(node, "dtype", dtype)
    self.set_attr_tensor(node, "value", value, dtype, shape)
    return node

  def set_attr_dtype(self, node, key, value):
    """Sets attr `key` of `node` to the datatype enum of `value`."""
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(type=value.as_datatype_enum))

  def set_attr_tensor(self, node, key, value, dtype, shape=None):
    """Sets attr `key` of `node` to a tensor proto built from `value`."""
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(
                value, dtype=dtype, shape=shape)))

  def testRemoveTrainingNodes(self):
    """Checks that CheckNumerics and Identity nodes are stripped."""
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"

    # Input graph: each constant feeds a CheckNumerics node and an Identity
    # node (which has a control dependency on the check); the two Identity
    # nodes feed the Add.
    graph_def = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = self.create_node_def("CheckNumerics", a_check_name,
                                        [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = self.create_node_def(
        "Identity", a_identity_name, [a_constant_name, "^" + a_check_name])
    graph_def.node.extend([a_identity_node])

    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = self.create_node_def("CheckNumerics", b_check_name,
                                        [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = self.create_node_def(
        "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
    graph_def.node.extend([b_identity_node])

    add_node = self.create_node_def("Add", add_name,
                                    [a_identity_name, b_identity_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    graph_def.node.extend([add_node])

    # Expected graph: the Add reads the two constants directly.
    expected_output = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant])
    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant])
    add_node = self.create_node_def("Add", add_name,
                                    [a_constant_name, b_constant_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    expected_output.node.extend([add_node])

    output = graph_util.remove_training_nodes(graph_def)
    self.assertProtoEquals(expected_output, output)

  def testRemoveIdentityChains(self):
    """Check that chains of Identity nodes are correctly pruned.

    Create a chain of four nodes, A, B, C, and D where A inputs B, B inputs C,
    and C inputs D. Nodes B and C are "Identity" and should be pruned, resulting
    in the nodes A and D, where A inputs D.
    """
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_node_def("Aop", "A", ["B"]),
        self.create_node_def("Identity", "B", ["C"]),
        self.create_node_def("Identity", "C", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    expected_graph_def = graph_pb2.GraphDef()
    expected_graph_def.node.extend([
        self.create_node_def("Aop", "A", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    self.assertProtoEquals(expected_graph_def,
                           graph_util.remove_training_nodes(graph_def))

  def testRemoveIdentityUsedAsControlInputInConst(self):
    """Check that Identity nodes used as control inputs are not removed."""
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
        self.create_node_def("Identity", "I", ["Base"]),
        self.create_node_def("BaseOp", "Base", [])
    ])

    # The graph is expected to come back unchanged.
    self.assertProtoEquals(graph_def,
                           graph_util.remove_training_nodes(graph_def))
class ConvertVariablesToConstantsTest(test.TestCase):
  """Tests for `graph_util.convert_variables_to_constants`.

  Each test builds a graph containing variables, freezes it and then checks
  the resulting `GraphDef` (absence of variable ops, blacklist handling,
  and/or numerical equivalence with the original graph).
  """

  def _get_tensors(self, sess, tensor_list):
    """Returns a list of Tensor objects from the Session."""
    return [
        sess.graph.get_tensor_by_name(tensor.name) for tensor in tensor_list
    ]

  def _get_tensor_names(self, tensors):
    """Returns a list of string names for the tensors specified."""
    return [tensor.name.split(":")[0] for tensor in tensors]

  def _find_node(self, graph_def, name):
    """Returns the node of `graph_def` called `name`, or None if absent.

    If several nodes carry the same name, the last one is returned.
    """
    found = None
    for node in graph_def.node:
      if node.name == name:
        found = node
    return found

  def _evaluate_graph_def(self, graph_def, inputs, outputs, input_data):
    """Evaluates the GraphDef using Sessions.

    Args:
      graph_def: `GraphDef` imported into a fresh graph (with name="").
      inputs: Tensors whose names identify the feed tensors.
      outputs: Tensors whose names identify the fetch tensors.
      input_data: Values fed to `inputs`, paired positionally.

    Returns:
      The result of `Session.run` on the output tensors.
    """
    with ops.Graph().as_default() as graph:
      importer.import_graph_def(graph_def, name="")
      sess = session.Session(graph=graph)

    # Close the session explicitly so that it does not outlive this helper.
    try:
      input_tensors = self._get_tensors(sess, inputs)
      output_tensors = self._get_tensors(sess, outputs)
      return sess.run(
          output_tensors, feed_dict=dict(zip(input_tensors, input_data)))
    finally:
      sess.close()

  def _ensure_no_variables_in_graph(self, graph_def):
    """Ensures there are no variables in the graph."""
    for node in graph_def.node:
      self.assertNotIn(
          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])

  def _test_converted_keras_model(self, model, constant_graph_def, input_data):
    """Compares the converted Keras model."""
    expected_value = model.predict(input_data)
    actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
                                            model.outputs, [input_data])
    # `actual_value` is a list of fetched outputs, hence the extra wrapping.
    np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)

  def _test_variable_to_const_conversion(self, use_resource):
    """Freezes a small graph and checks whitelist/blacklist handling.

    Args:
      use_resource: Passed to `variable_scope` to select resource or
        reference variables.
    """
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=use_resource):
        variable_node = variable_scope.get_variable(
            "variable_node", initializer=1.0)
        another_variable = variable_scope.get_variable(
            "unused_variable_node", initializer=1.0)
        output_node = math_ops_lib.multiply(
            variable_node, 2.0, name="output_node")
        with session.Session() as sess:
          self.evaluate(variable_node.initializer)
          output = self.evaluate(output_node)
          self.assertNear(2.0, output, 0.00001)
          variable_graph_def = sess.graph.as_graph_def()
          # First get the constant_graph_def when variable_names_whitelist is
          # set, note that if variable_names_whitelist is not set an error will
          # be thrown because unused_variable_node is not initialized.
          constant_graph_def = graph_util.convert_variables_to_constants(
              sess,
              variable_graph_def, ["output_node"],
              variable_names_whitelist={"variable_node"})

          # Then initialize the unused variable, and get another
          # constant_graph_def when variable_names_whitelist is not set.
          self.evaluate(another_variable.initializer)
          constant_graph_def_without_variable_whitelist = (
              graph_util.convert_variables_to_constants(
                  sess, variable_graph_def, ["output_node"]))

          # The unused variable should be cleared so the two graphs should be
          # equivalent.
          self.assertEqual(
              str(constant_graph_def),
              str(constant_graph_def_without_variable_whitelist))

          # Test variable name black list. This should result in the variable
          # not being a const.
          constant_graph_def_with_blacklist = (
              graph_util.convert_variables_to_constants(
                  sess,
                  variable_graph_def, ["output_node"],
                  variable_names_blacklist={"variable_node"}))
          blacklisted_node = self._find_node(
              constant_graph_def_with_blacklist, "variable_node")
          self.assertIsNotNone(blacklisted_node)
          if use_resource:
            self.assertEqual(blacklisted_node.op, "VarHandleOp")
          else:
            self.assertEqual(blacklisted_node.op, "VariableV2")

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      _ = importer.import_graph_def(constant_graph_def, name="")
      self.assertEqual(4, len(constant_graph_def.node))
      self._ensure_no_variables_in_graph(constant_graph_def)
      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = self.evaluate(output_node)
        self.assertNear(2.0, output, 0.00001)

  def test_resource_variable_can_be_written_after_blacklisting(self):
    """Checks that a blacklisted resource variable stays assignable."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        variable_node = variable_scope.get_variable(
            "variable_node", initializer=1.0)
        another_variable = variable_scope.get_variable(
            "unused_variable_node", initializer=2.0)
        # Reading output_node first adds another_variable to variable_node.
        with ops.control_dependencies([
            variable_node.assign(another_variable + variable_node)]):
          output_node = array_ops.identity(variable_node, name="output_node")
        initializer_name = variable_node.initializer.name
        with session.Session() as sess:
          self.evaluate(variable_node.initializer)
          self.evaluate(another_variable.initializer)
          output = self.evaluate(output_node)
          self.assertNear(3.0, output, 0.00001)
          variable_graph_def = sess.graph.as_graph_def()

          # Test variable name black list. This should result in the variable
          # not being a const. Furthermore, the paths that read from and assign
          # to the blacklisted variable should continue to be valid.
          constant_graph_def_with_blacklist = (
              graph_util.convert_variables_to_constants(
                  sess,
                  variable_graph_def, ["output_node", initializer_name],
                  variable_names_blacklist={"variable_node"}))

          blacklisted_node = self._find_node(
              constant_graph_def_with_blacklist, "variable_node")
          self.assertIsNotNone(blacklisted_node)
          self.assertEqual(blacklisted_node.op, "VarHandleOp")

    # Now we make sure another_variable is now a constant, but the original
    # variable is not, and that the graph can be executed and the variable is
    # updated with each execution.
    with ops.Graph().as_default():
      _ = importer.import_graph_def(constant_graph_def_with_blacklist, name="")
      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        self.evaluate(sess.graph.get_operation_by_name(initializer_name))
        output = self.evaluate(output_node)
        self.assertNear(3.0, output, 0.00001)
        output = self.evaluate(output_node)
        self.assertNear(5.0, output, 0.00001)

  def _inline_functions(self, graph_def, arrays):
    """Runs the Grappler "function" optimizer over `graph_def`.

    Args:
      graph_def: `GraphDef` to optimize.
      arrays: Names recorded in the meta graph's "train_op" collection.

    Returns:
      The result of `tf_optimizer.OptimizeGraph`.
    """
    meta_graph = export_meta_graph(graph_def=graph_def)
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_collection.node_list.value.extend(arrays)
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

    # Initialize RewriterConfig with everything disabled except function
    # inlining.
    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.optimizers.append("function")
    return tf_optimizer.OptimizeGraph(config, meta_graph)

  def _test_convert_variables_with_functions(self, inline_functions):
    """Freezes a graph with functions.

    Args:
      inline_functions: Whether to run `_inline_functions` on the graph
        before freezing it.
    """

    @function.Defun(dtypes.float32)
    def plus_one(x):
      return x + 1.0

    with ops.Graph().as_default():
      variable_node = variables.Variable(1.0, name="variable_node")
      _ = variables.Variable(1.0, name="unused_variable_node")
      defun_node = plus_one(variable_node)
      _ = math_ops_lib.multiply(defun_node, 2.0, name="output_node")

      with session.Session() as sess:
        self.evaluate(variables.variables_initializer([variable_node]))
        variable_graph_def = sess.graph.as_graph_def()

        if inline_functions:
          # Run Grappler to create the VarOpHandle --> Placeholder -->
          # ResourceVariable pattern.
          variable_graph_def = self._inline_functions(
              variable_graph_def, ["variable_node", "output_node"])

        constant_graph_def = graph_util.convert_variables_to_constants(
            sess, variable_graph_def, ["output_node"])

    self._ensure_no_variables_in_graph(constant_graph_def)

  def testReferenceVariables(self):
    """Freezes a graph with reference variables."""
    self._test_variable_to_const_conversion(use_resource=False)

  def testResourceVariables(self):
    """Freezes a graph with resource variables."""
    self._test_variable_to_const_conversion(use_resource=True)

  def testWithFunctions(self):
    """Freezes a graph with functions."""
    self._test_convert_variables_with_functions(inline_functions=False)

  def testWithInlinedFunctions(self):
    """Freezes a graph with functions that have been inlined using Grappler."""
    self._test_convert_variables_with_functions(inline_functions=True)

  def testWithEmbeddings(self):
    """Freezes a graph with embeddings."""
    ops.disable_eager_execution()
    state_input = keras.layers.Input(
        shape=(1,), name="state_input", dtype="int32")
    output = keras.layers.Embedding(
        output_dim=16, input_dim=100, input_length=1, name="state")(
            state_input)
    model = keras.models.Model(inputs=[state_input], outputs=[output])
    model.compile(
        loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")

    # Freeze the graph.
    sess = keras.backend.get_session()
    variable_graph_def = sess.graph_def
    output_tensor = self._get_tensor_names(model.outputs)
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, variable_graph_def, output_tensor)

    # Validate converted graph.
    # NOTE(review): random_sample draws from [0, 1), so after the int32 cast
    # the only index ever looked up is 0 -- confirm this is intended.
    input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)
    self._ensure_no_variables_in_graph(constant_graph_def)
    self._test_converted_keras_model(model, constant_graph_def, input_data)

  def testGraphWithSwitch(self):
    """Freezes a graph which contains a Switch with type RESOURCE_DT."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        x = variable_scope.get_variable("var_x", initializer=1.0)
        y = variable_scope.get_variable("var_y", initializer=2.0)
        f1 = lambda: variable_scope.get_variable("var_f1", initializer=17.0)
        f2 = lambda: variable_scope.get_variable("var_f2", initializer=23.0)
        cond_node = control_flow_ops.case([(gen_math_ops.less(x, y), f1)],
                                          default=f2)
        _ = math_ops_lib.multiply(cond_node, 2.0, name="output_node")

        with session.Session() as sess:
          sess.run(variables.global_variables_initializer())
          variable_graph_def = sess.graph.as_graph_def()

          constant_graph_def = graph_util.convert_variables_to_constants(
              sess, variable_graph_def, ["output_node"])

    self._ensure_no_variables_in_graph(constant_graph_def)

  def testKerasBatchNorm(self):
    """Freezes a graph with Keras batch norm."""
    ops.disable_eager_execution()
    inputs = keras.layers.Input(shape=(128, 128, 1))
    batch_norm = keras.layers.BatchNormalization()(inputs)
    model = keras.models.Model(inputs, batch_norm, name="test")
    model.compile(
        optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    tensor_names = [tensor.name for tensor in model.inputs + model.outputs]

    # Freeze the graph.
    sess = keras.backend.get_session()
    variable_graph_def = sess.graph_def
    variable_graph_def = self._inline_functions(variable_graph_def,
                                                tensor_names)
    output_tensor = self._get_tensor_names(model.outputs)
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, variable_graph_def, output_tensor)

    # Validate converted graph.
    # NOTE(review): the int32 cast of [0, 1) samples yields an all-zero input.
    input_data = np.array(
        np.random.random_sample([1, 128, 128, 1]), dtype=np.int32)
    self._ensure_no_variables_in_graph(constant_graph_def)
    self._test_converted_keras_model(model, constant_graph_def, input_data)

  def testLSTM(self):
    """Freezes a Keras LSTM."""
    ops.disable_eager_execution()
    model = keras.models.Sequential(
        [keras.layers.LSTM(units=10, input_shape=(10, 10))])
    tensor_names = [tensor.name for tensor in model.inputs + model.outputs]

    # Freeze the model.
    sess = keras.backend.get_session()
    variable_graph_def = sess.graph_def
    variable_graph_def = self._inline_functions(variable_graph_def,
                                                tensor_names)
    output_tensor = self._get_tensor_names(model.outputs)
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, variable_graph_def, output_tensor)

    # Validate converted graph.
    # NOTE(review): the int32 cast of [0, 1) samples yields an all-zero input.
    input_data = np.array(np.random.random_sample([10, 10, 10]), dtype=np.int32)
    self._ensure_no_variables_in_graph(constant_graph_def)
    self._test_converted_keras_model(model, constant_graph_def, input_data)
if __name__ == "__main__":
  # Only run the tests when this file is executed as a script, not on import.
  test.main()