- assertEquals -> assertEqual - assertRaisesRegexp -> assertRaisesRegex - assertRegexpMatches -> assertRegex PiperOrigin-RevId: 319118081 Change-Id: Ieb457128522920ab55d6b69a7f244ab798a7d689
309 lines
12 KiB
Python
309 lines
12 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for tensorflow.python.framework.graph_util."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.core.framework import attr_value_pb2
|
|
from tensorflow.core.framework import graph_pb2
|
|
from tensorflow.core.framework import node_def_pb2
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import graph_util
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_util
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import gen_state_ops
|
|
from tensorflow.python.ops import variables
|
|
from tensorflow.python.platform import test
|
|
|
|
|
|
# Utility device function to use for testing
|
|
def test_device_func_pin_variable_to_cpu(op):
  """Device function that pins otherwise-unplaced variable ops to the CPU.

  Args:
    op: The operation being placed; only `op.device` and `op.node_def.op`
      are read.

  Returns:
    `op.device` unchanged when it is already set. Otherwise "/cpu:0" for ops
    of type "Variable" or "VariableV2", and `op.device` (still unset) for
    every other op type.
  """
  if op.device:
    # An explicitly requested device always wins over the pinning rule.
    return op.device
  if op.node_def.op in ("Variable", "VariableV2"):
    return "/cpu:0"
  return op.device


# The helper's name starts with "test_", so pytest/nose discovery would
# otherwise collect it as a test case and error out (no `op` fixture exists).
test_device_func_pin_variable_to_cpu.__test__ = False
|
|
|
|
|
|
class DeviceFunctionsTest(test.TestCase):
  """Tests for device functions and `graph_util` GraphDef transforms."""

  def testTwoDeviceFunctions(self):
    """Pinning applies only inside its scope; an inner explicit device wins."""
    with ops.Graph().as_default() as g:
      var_0 = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="var_0",
          container="",
          shared_name="")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_1 = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="var_1",
            container="",
            shared_name="")
      var_2 = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="var_2",
          container="",
          shared_name="")
      var_3 = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="var_3",
          container="",
          shared_name="")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_4 = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="var_4",
            container="",
            shared_name="")
        with g.device("/device:GPU:0"):
          var_5 = gen_state_ops.variable(
              shape=[1],
              dtype=dtypes.float32,
              name="var_5",
              container="",
              shared_name="")
        var_6 = gen_state_ops.variable(
            shape=[1],
            dtype=dtypes.float32,
            name="var_6",
            container="",
            shared_name="")

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, None)
    self.assertDeviceEqual(var_3.device, None)
    self.assertDeviceEqual(var_4.device, "/device:CPU:0")
    self.assertDeviceEqual(var_5.device, "/device:GPU:0")
    self.assertDeviceEqual(var_6.device, "/device:CPU:0")

  @test_util.run_v1_only("b/120545219")
  def testNestedDeviceFunctions(self):
    """Inner device functions/strings override the outer CPU-pinning rule."""
    with ops.Graph().as_default():
      var_0 = variables.VariableV1(0)
      with ops.device(test_device_func_pin_variable_to_cpu):
        var_1 = variables.VariableV1(1)
        with ops.device(lambda op: "/device:GPU:0"):
          var_2 = variables.VariableV1(2)
        with ops.device("/device:GPU:0"):  # Implicit merging device function.
          var_3 = variables.VariableV1(3)

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, "/device:GPU:0")
    self.assertDeviceEqual(var_3.device, "/device:GPU:0")

  def testExplicitDevice(self):
    """Ops created under an explicit device string report that device."""
    with ops.Graph().as_default() as g:
      const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/job:ps"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, None)
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/job:ps")

  def testDefaultDevice(self):
    """Explicit device strings nested under the pinning function are kept."""
    with ops.Graph().as_default() as g, g.device(
        test_device_func_pin_variable_to_cpu):
      with g.device("/job:ps"):
        const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/replica:0"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, "/job:ps")
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/replica:0")

  def testExtractSubGraph(self):
    """Only nodes reachable from the destination survive, in input order."""
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    n1.input.extend(["n5"])
    n2 = graph_def.node.add()
    n2.name = "n2"
    # Take the first output of the n1 node as the input.
    n2.input.extend(["n1:0"])
    n3 = graph_def.node.add()
    n3.name = "n3"
    # Add a control input (which isn't really needed by the kernel, but
    # rather to enforce execution order between nodes).
    n3.input.extend(["^n2"])
    n4 = graph_def.node.add()
    n4.name = "n4"

    # Loops in the graph are fine as well: n1 and n5 feed each other.
    n5 = graph_def.node.add()
    n5.name = "n5"
    n5.input.extend(["n1"])

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    # n4 is not reachable from n3 and must be dropped; the remaining nodes
    # keep the order in which they appear in the input GraphDef.
    self.assertEqual(["n1", "n2", "n3", "n5"],
                     [node.name for node in sub_graph.node])

  def testExtractSubGraphWithInvalidDestNodes(self):
    """A bare string for `dest_nodes` is rejected with a TypeError."""
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    with self.assertRaisesRegex(TypeError, "must be a list"):
      graph_util.extract_sub_graph(graph_def, "n1")

  def create_node_def(self, op, name, inputs):
    """Returns a NodeDef with the given op type, name and input names."""
    new_node = node_def_pb2.NodeDef()
    new_node.op = op
    new_node.name = name
    new_node.input.extend(inputs)
    return new_node

  def create_constant_node_def(self,
                               name,
                               value,
                               dtype,
                               shape=None,
                               inputs=None):
    """Returns a "Const" NodeDef with its "dtype" and "value" attrs set.

    Args:
      name: Node name.
      value: Python value converted into the node's "value" tensor.
      dtype: A `DType` stored in the "dtype" attr and used for the tensor.
      shape: Optional tensor shape passed to `make_tensor_proto`.
      inputs: Optional input names (e.g. control inputs); defaults to none.
    """
    node = self.create_node_def("Const", name, inputs or [])
    self.set_attr_dtype(node, "dtype", dtype)
    self.set_attr_tensor(node, "value", value, dtype, shape)
    return node

  def set_attr_dtype(self, node, key, value):
    """Sets attr `key` of `node` to the DataType enum of the `DType` value."""
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(type=value.as_datatype_enum))

  def set_attr_tensor(self, node, key, value, dtype, shape=None):
    """Sets attr `key` of `node` to a TensorProto built from `value`."""
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(
                value, dtype=dtype, shape=shape)))

  def testRemoveTrainingNodes(self):
    """CheckNumerics and wrapping Identity nodes are stripped from the graph.

    The expected graph keeps only the two constants and an Add node that
    reads them directly.
    """
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    graph_def = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = self.create_node_def("CheckNumerics", a_check_name,
                                        [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = self.create_node_def(
        "Identity", a_identity_name, [a_constant_name, "^" + a_check_name])
    graph_def.node.extend([a_identity_node])
    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = self.create_node_def("CheckNumerics", b_check_name,
                                        [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = self.create_node_def(
        "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = self.create_node_def("Add", add_name,
                                    [a_identity_name, b_identity_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    graph_def.node.extend([add_node])

    expected_output = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant])
    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant])
    add_node = self.create_node_def("Add", add_name,
                                    [a_constant_name, b_constant_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    expected_output.node.extend([add_node])

    output = graph_util.remove_training_nodes(graph_def)
    self.assertProtoEquals(expected_output, output)

  def testRemoveIdentityChains(self):
    """Check that chains of Identity nodes are correctly pruned.

    Create a chain of four nodes, A, B, C, and D where A inputs B, B inputs C,
    and C inputs D. Nodes B and C are "Identity" and should be pruned, resulting
    in the nodes A and D, where A inputs D.
    """
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_node_def("Aop", "A", ["B"]),
        self.create_node_def("Identity", "B", ["C"]),
        self.create_node_def("Identity", "C", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    expected_graph_def = graph_pb2.GraphDef()
    expected_graph_def.node.extend([
        self.create_node_def("Aop", "A", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    self.assertProtoEquals(expected_graph_def,
                           graph_util.remove_training_nodes(graph_def))

  def testRemoveIdentityUsedAsControlInputInConst(self):
    """Check that Identity nodes used as control inputs are not removed."""
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
        self.create_node_def("Identity", "I", ["Base"]),
        self.create_node_def("BaseOp", "Base", [])
    ])

    self.assertProtoEquals(graph_def,
                           graph_util.remove_training_nodes(graph_def))
|
|
|
|
|
|
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test entry point.
  test.main()
|