Before this change there were two different code paths for graph freezing, one for v1 graphs and one for v2 graphs. They largely did the same thing, but each had capabilities the other lacked, and each had its own bugs. This change rewrites the previous v2 logic so it can cope with session-based graphs and can convert only a subset of the variables, and changes the previous convert_variables_to_constants call to proxy into the new logic. The new logic is built around a more "graphy" algorithm: variables are converted to constants, and that conversion is then propagated through the graph by following graph edges. This should make it easier to understand what is going on, and to change the code later. More granular tests were added to check that the right graph manipulations are performed; to support them, some graph-merging infrastructure had to be created in the tests.

PiperOrigin-RevId: 315497124
Change-Id: I3a33acc804b5dc9628c208df8fd1b7c59f906ddb
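To make the edge-following idea concrete, here is a minimal sketch of the conversion written against the GraphDef proto API. It is illustrative only, not the actual implementation: the function name freeze_variable_subset and the captured_values mapping (variable node name -> (value, dtype), assumed to have been read out of a session or from live resource variables beforehand) are hypothetical, and only ReadVariableOp consumers are rewritten here, whereas the real code handles many more node types.

from tensorflow.python.framework import tensor_util


def freeze_variable_subset(graph_def, captured_values):
  """Toy sketch: converts the variables named in `captured_values` in place.

  `captured_values` is a hypothetical mapping from a variable node name to a
  (value, dtype) pair captured ahead of time.
  """
  converted = set(captured_values)
  for node in graph_def.node:
    if node.name in converted:
      # Step 1: rewrite each chosen variable node itself as a constant.
      value, dtype = captured_values[node.name]
      node.op = "Const"
      node.ClearField("input")
      node.attr.clear()
      node.attr["dtype"].type = dtype.as_datatype_enum
      node.attr["value"].tensor.CopyFrom(
          tensor_util.make_tensor_proto(value, dtype=dtype))
  for node in graph_def.node:
    # Step 2: propagate along graph edges. Each read of a converted variable
    # becomes an Identity of the new constant, so downstream consumers keep
    # working unchanged.
    if (node.op == "ReadVariableOp" and node.input and
        node.input[0].split(":")[0] in converted):
      node.op = "Identity"
      node.attr["T"].CopyFrom(node.attr["dtype"])
      del node.attr["dtype"]
  return graph_def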
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.client.graph_util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

# Utility device function to use for testing
def test_device_func_pin_variable_to_cpu(op):
  if op.device:
    return op.device
  return "/cpu:0" if op.node_def.op in ["Variable", "VariableV2"] else op.device

class DeviceFunctionsTest(test.TestCase):

  def testTwoDeviceFunctions(self):
    with ops.Graph().as_default() as g:
      var_0 = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="var_0",
          container="",
          shared_name="")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_1 = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="var_1",
            container="",
            shared_name="")
      var_2 = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="var_2",
          container="",
          shared_name="")
      var_3 = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="var_3",
          container="",
          shared_name="")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_4 = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="var_4",
            container="",
            shared_name="")
        with g.device("/device:GPU:0"):
          var_5 = gen_state_ops.variable(
              shape=[1],
              dtype=dtypes.float32,
              name="var_5",
              container="",
              shared_name="")
        var_6 = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="var_6",
            container="",
            shared_name="")

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, None)
    self.assertDeviceEqual(var_3.device, None)
    self.assertDeviceEqual(var_4.device, "/device:CPU:0")
    self.assertDeviceEqual(var_5.device, "/device:GPU:0")
    self.assertDeviceEqual(var_6.device, "/device:CPU:0")

  @test_util.run_v1_only("b/120545219")
  def testNestedDeviceFunctions(self):
    with ops.Graph().as_default():
      var_0 = variables.VariableV1(0)
      with ops.device(test_device_func_pin_variable_to_cpu):
        var_1 = variables.VariableV1(1)
        with ops.device(lambda op: "/device:GPU:0"):
          var_2 = variables.VariableV1(2)
          with ops.device("/device:GPU:0"):  # Implicit merging device function.
            var_3 = variables.VariableV1(3)

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, "/device:GPU:0")
    self.assertDeviceEqual(var_3.device, "/device:GPU:0")

  def testExplicitDevice(self):
    with ops.Graph().as_default() as g:
      const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/job:ps"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, None)
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/job:ps")

  def testDefaultDevice(self):
    with ops.Graph().as_default() as g, g.device(
        test_device_func_pin_variable_to_cpu):
      with g.device("/job:ps"):
        const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/replica:0"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, "/job:ps")
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/replica:0")

  def testExtractSubGraph(self):
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    n1.input.extend(["n5"])
    n2 = graph_def.node.add()
    n2.name = "n2"
    # Take the first output of the n1 node as the input.
    n2.input.extend(["n1:0"])
    n3 = graph_def.node.add()
    n3.name = "n3"
    # Add a control input (which isn't really needed by the kernel, but
    # rather to enforce execution order between nodes).
    n3.input.extend(["^n2"])
    n4 = graph_def.node.add()
    n4.name = "n4"

    # It is fine to have loops in the graph as well.
    n5 = graph_def.node.add()
    n5.name = "n5"
    n5.input.extend(["n1"])

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    self.assertEqual("n1", sub_graph.node[0].name)
    self.assertEqual("n2", sub_graph.node[1].name)
    self.assertEqual("n3", sub_graph.node[2].name)
    self.assertEqual("n5", sub_graph.node[3].name)

  def testExtractSubGraphWithInvalidDestNodes(self):
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    with self.assertRaisesRegexp(TypeError, "must be a list"):
      graph_util.extract_sub_graph(graph_def, "n1")

  def create_node_def(self, op, name, inputs):
    new_node = node_def_pb2.NodeDef()
    new_node.op = op
    new_node.name = name
    new_node.input.extend(inputs)
    return new_node

  def create_constant_node_def(self,
                               name,
                               value,
                               dtype,
                               shape=None,
                               inputs=None):
    node = self.create_node_def("Const", name, inputs or [])
    self.set_attr_dtype(node, "dtype", dtype)
    self.set_attr_tensor(node, "value", value, dtype, shape)
    return node

  def set_attr_dtype(self, node, key, value):
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(type=value.as_datatype_enum))

  def set_attr_tensor(self, node, key, value, dtype, shape=None):
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(
                value, dtype=dtype, shape=shape)))

  def testRemoveTrainingNodes(self):
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    graph_def = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = self.create_node_def("CheckNumerics", a_check_name,
                                        [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = self.create_node_def(
        "Identity", a_identity_name, [a_constant_name, "^" + a_check_name])
    graph_def.node.extend([a_identity_node])
    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = self.create_node_def("CheckNumerics", b_check_name,
                                        [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = self.create_node_def(
        "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = self.create_node_def("Add", add_name,
                                    [a_identity_name, b_identity_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    graph_def.node.extend([add_node])

    expected_output = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant])
    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant])
    add_node = self.create_node_def("Add", add_name,
                                    [a_constant_name, b_constant_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    expected_output.node.extend([add_node])

    output = graph_util.remove_training_nodes(graph_def)
    self.assertProtoEquals(expected_output, output)

  def testRemoveIdentityChains(self):
    """Check that chains of Identity nodes are correctly pruned.

    Create a chain of four nodes, A, B, C, and D where A inputs B, B inputs C,
    and C inputs D. Nodes B and C are "Identity" and should be pruned, resulting
    in the nodes A and D, where A inputs D.
    """
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_node_def("Aop", "A", ["B"]),
        self.create_node_def("Identity", "B", ["C"]),
        self.create_node_def("Identity", "C", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    expected_graph_def = graph_pb2.GraphDef()
    expected_graph_def.node.extend([
        self.create_node_def("Aop", "A", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    self.assertProtoEquals(expected_graph_def,
                           graph_util.remove_training_nodes(graph_def))

  def testRemoveIdentityUsedAsControlInputInConst(self):
    """Check that Identity nodes used as control inputs are not removed."""
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
        self.create_node_def("Identity", "I", ["Base"]),
        self.create_node_def("BaseOp", "Base", [])
    ])

    self.assertProtoEquals(graph_def,
                           graph_util.remove_training_nodes(graph_def))


if __name__ == "__main__":
  test.main()