STT-tensorflow/tensorflow/python/framework/graph_util_test.py
Cesar Crusius ed52277026 Unify convert to constants logic.
Before this change there were two different code paths for dealing
with graph freezing for v1 and v2 graphs. They largely did the same
thing, but each path had capabilities the other lacked, and each had
its own bugs.

This change rewrites the previous v2 logic so it can cope with
session-based graphs and allow for the conversion of a subset of the
variables, and changes the previous convert_variables_to_constants
call to proxy into the new logic.

The new logic is built around a more "graphy" algorithm: variables are
converted to constants, and that conversion is then propagated through
the graph by following the graph edges. This hopefully makes it easier
to understand what is going on, and to change it later on.
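
As a rough illustration of that propagation idea, here is a toy sketch
on a dict-based graph. This is not the actual convert_to_constants
implementation; freeze_variables and the node format are hypothetical,
and the real code handles many more op types:

    from collections import deque

    def freeze_variables(nodes, variable_values):
      """Toy freeze: nodes maps name -> {"op", "inputs"}, values -> constants."""
      # Step 1: convert the selected variable nodes into constants.
      worklist = deque()
      for name, value in variable_values.items():
        nodes[name] = {"op": "Const", "inputs": [], "value": value}
        worklist.append(name)
      # Step 2: propagate along graph edges. A read of a frozen variable is
      # now a plain pass-through, so it becomes an Identity and is itself
      # added to the worklist so its own consumers get visited.
      while worklist:
        frozen = worklist.popleft()
        for name, node in nodes.items():
          if frozen in node["inputs"] and node["op"] == "ReadVariableOp":
            node["op"] = "Identity"
            worklist.append(name)
      return nodes

For example, a graph v (variable) -> r (read of v) -> mul(r, x) would come
out as Const -> Identity -> Mul after freezing v.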

More granular tests were added to check that the right graph
manipulations are performed. To support this, some graph-merging
infrastructure had to be created in the test.

PiperOrigin-RevId: 315497124
Change-Id: I3a33acc804b5dc9628c208df8fd1b7c59f906ddb
2020-06-09 09:32:59 -07:00


# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.client.graph_util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


# Utility device function to use for testing
def test_device_func_pin_variable_to_cpu(op):
  if op.device:
    return op.device
  return "/cpu:0" if op.node_def.op in ["Variable", "VariableV2"] else op.device


class DeviceFunctionsTest(test.TestCase):

  def testTwoDeviceFunctions(self):
    with ops.Graph().as_default() as g:
      var_0 = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="var_0",
          container="",
          shared_name="")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_1 = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="var_1",
            container="",
            shared_name="")
      var_2 = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="var_2",
          container="",
          shared_name="")
      var_3 = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="var_3",
          container="",
          shared_name="")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_4 = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="var_4",
            container="",
            shared_name="")
        with g.device("/device:GPU:0"):
          var_5 = gen_state_ops.variable(
              shape=[1],
              dtype=dtypes.float32,
              name="var_5",
              container="",
              shared_name="")
        var_6 = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="var_6",
            container="",
            shared_name="")

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, None)
    self.assertDeviceEqual(var_3.device, None)
    self.assertDeviceEqual(var_4.device, "/device:CPU:0")
    self.assertDeviceEqual(var_5.device, "/device:GPU:0")
    self.assertDeviceEqual(var_6.device, "/device:CPU:0")

  @test_util.run_v1_only("b/120545219")
  def testNestedDeviceFunctions(self):
    with ops.Graph().as_default():
      var_0 = variables.VariableV1(0)
      with ops.device(test_device_func_pin_variable_to_cpu):
        var_1 = variables.VariableV1(1)
        with ops.device(lambda op: "/device:GPU:0"):
          var_2 = variables.VariableV1(2)
          with ops.device("/device:GPU:0"):  # Implicit merging device function.
            var_3 = variables.VariableV1(3)

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, "/device:GPU:0")
    self.assertDeviceEqual(var_3.device, "/device:GPU:0")

  def testExplicitDevice(self):
    with ops.Graph().as_default() as g:
      const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/job:ps"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, None)
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/job:ps")

  def testDefaultDevice(self):
    with ops.Graph().as_default() as g, g.device(
        test_device_func_pin_variable_to_cpu):
      with g.device("/job:ps"):
        const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/replica:0"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, "/job:ps")
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/replica:0")

  def testExtractSubGraph(self):
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    n1.input.extend(["n5"])
    n2 = graph_def.node.add()
    n2.name = "n2"
    # Take the first output of the n1 node as the input.
    n2.input.extend(["n1:0"])
    n3 = graph_def.node.add()
    n3.name = "n3"
    # Add a control input (which isn't really needed by the kernel, but
    # rather to enforce execution order between nodes).
    n3.input.extend(["^n2"])
    n4 = graph_def.node.add()
    n4.name = "n4"
    # It is fine to have loops in the graph as well (here, n1 -> n5 -> n1).
    n5 = graph_def.node.add()
    n5.name = "n5"
    n5.input.extend(["n1"])

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    self.assertEqual("n1", sub_graph.node[0].name)
    self.assertEqual("n2", sub_graph.node[1].name)
    self.assertEqual("n3", sub_graph.node[2].name)
    self.assertEqual("n5", sub_graph.node[3].name)

  def testExtractSubGraphWithInvalidDestNodes(self):
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    with self.assertRaisesRegexp(TypeError, "must be a list"):
      graph_util.extract_sub_graph(graph_def, "n1")

  def create_node_def(self, op, name, inputs):
    new_node = node_def_pb2.NodeDef()
    new_node.op = op
    new_node.name = name
    new_node.input.extend(inputs)
    return new_node

  def create_constant_node_def(self,
                               name,
                               value,
                               dtype,
                               shape=None,
                               inputs=None):
    node = self.create_node_def("Const", name, inputs or [])
    self.set_attr_dtype(node, "dtype", dtype)
    self.set_attr_tensor(node, "value", value, dtype, shape)
    return node

  def set_attr_dtype(self, node, key, value):
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(type=value.as_datatype_enum))

  def set_attr_tensor(self, node, key, value, dtype, shape=None):
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(
                value, dtype=dtype, shape=shape)))

  def testRemoveTrainingNodes(self):
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"

    graph_def = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = self.create_node_def("CheckNumerics", a_check_name,
                                        [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = self.create_node_def(
        "Identity", a_identity_name, [a_constant_name, "^" + a_check_name])
    graph_def.node.extend([a_identity_node])
    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = self.create_node_def("CheckNumerics", b_check_name,
                                        [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = self.create_node_def(
        "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = self.create_node_def("Add", add_name,
                                    [a_identity_name, b_identity_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    graph_def.node.extend([add_node])

    expected_output = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant])
    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant])
    add_node = self.create_node_def("Add", add_name,
                                    [a_constant_name, b_constant_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    expected_output.node.extend([add_node])

    output = graph_util.remove_training_nodes(graph_def)
    self.assertProtoEquals(expected_output, output)

  def testRemoveIdentityChains(self):
    """Check that chains of Identity nodes are correctly pruned.

    Create a chain of four nodes, A, B, C, and D where A inputs B, B inputs C,
    and C inputs D. Nodes B and C are "Identity" and should be pruned, resulting
    in the nodes A and D, where A inputs D.
    """
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_node_def("Aop", "A", ["B"]),
        self.create_node_def("Identity", "B", ["C"]),
        self.create_node_def("Identity", "C", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    expected_graph_def = graph_pb2.GraphDef()
    expected_graph_def.node.extend([
        self.create_node_def("Aop", "A", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    self.assertProtoEquals(expected_graph_def,
                           graph_util.remove_training_nodes(graph_def))

  def testRemoveIdentityUsedAsControlInputInConst(self):
    """Check that Identity nodes used as control inputs are not removed."""
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
        self.create_node_def("Identity", "I", ["Base"]),
        self.create_node_def("BaseOp", "Base", [])
    ])

    self.assertProtoEquals(graph_def,
                           graph_util.remove_training_nodes(graph_def))


if __name__ == "__main__":
  test.main()