STT-tensorflow/tensorflow/python/framework/meta_graph_test.py
A. Unique TensorFlower 0554618036 Include Ops that are used via PartitionedCalls to MetaGraphDef.MetaInfoDef.stripped_op_list
PiperOrigin-RevId: 322706036
Change-Id: I3f307d07a9d38aeca34f7c550d857c76aed37005
2020-07-22 19:39:13 -07:00

1049 lines
43 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tensorflow.python.framework.meta_graph.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os.path
import random
import shutil
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import function
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import queue_runner_impl
# pylint: disable=invalid-name
def _TestDir(test_name):
  """Returns an empty, newly created directory named after `test_name`.

  The directory lives under `test.get_temp_dir()`. Whatever already exists
  at that path is removed with `shutil.rmtree` before it is recreated.
  """
  scratch_path = os.path.join(test.get_temp_dir(), test_name)
  stale_path_present = os.path.exists(scratch_path)
  if stale_path_present:
    # Drop leftovers so the caller always starts from an empty directory.
    shutil.rmtree(scratch_path)
  gfile.MakeDirs(scratch_path)
  return scratch_path
# pylint: enable=invalid-name
class SimpleMetaGraphTest(test.TestCase):
@test_util.run_deprecated_v1
def testNoVariables(self):
  """Round-trips a variable-free graph through an exported MetaGraphDef."""
  meta_path = os.path.join(_TestDir("no_variables"), "metafile")
  feed_value = -10  # Arbitrary scalar fed to the placeholder.

  with self.session(graph=ops.Graph()) as sess:
    # Smallest useful graph without variables: add_offset = input + 42.
    graph_input = array_ops.placeholder(
        dtypes.float32, shape=[], name="input")
    forty_two = constant_op.constant(42, dtype=dtypes.float32, name="offset")
    graph_output = math_ops.add(graph_input, forty_two, name="add_offset")

    # Register both endpoints in collections so they can be looked up after
    # the import below.
    ops.add_to_collection("input_tensor", graph_input)
    ops.add_to_collection("output_tensor", graph_output)

    output_value = sess.run(graph_output, {graph_input: feed_value})
    self.assertEqual(output_value, 32)

    # Export the graph, together with the two collections, to `meta_path`.
    shaped_graph_def = ops.get_default_graph().as_graph_def(add_shapes=True)
    meta_graph_def, var_list = meta_graph.export_scoped_meta_graph(
        filename=meta_path,
        graph_def=shaped_graph_def,
        collection_list=["input_tensor", "output_tensor"],
        saver_def=None)
    self.assertTrue(meta_graph_def.HasField("meta_info_def"))
    meta_info = meta_graph_def.meta_info_def
    self.assertNotEqual(meta_info.tensorflow_version, "")
    self.assertNotEqual(meta_info.tensorflow_git_version, "")
    self.assertEqual({}, var_list)

  # Import the previously exported meta graph into a brand new graph.
  with self.session(graph=ops.Graph()) as sess:
    meta_graph.import_scoped_meta_graph(meta_path)

    # Export again and compare the result against the original export.
    reexported_def, _ = meta_graph.export_scoped_meta_graph(
        meta_path + "_new")
    test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
                                             reexported_def)

    # The collections must still hand back the graph endpoints.
    restored_input = ops.get_collection("input_tensor")[0]
    restored_output = ops.get_collection("output_tensor")[0]
    # The imported graph must compute the same value as the original one.
    restored_value = sess.run(restored_output, {restored_input: feed_value})
    self.assertEqual(restored_value, output_value)
@test_util.run_deprecated_v1
def testStrippedOpListNestedFunctions(self):
  """Checks the stripped op list before and after calling nested Defuns."""

  def _stripped_op_names():
    # Op names in the stripped op list of the current default graph.
    current_graph_def = ops.get_default_graph().as_graph_def()
    stripped = meta_graph.stripped_op_list_for_graph(current_graph_def)
    return [op_def.name for op_def in stripped.op]

  with self.cached_session():
    # f1 wraps f0, so Square sits two function levels deep.
    @function.Defun(dtypes.int32)
    def f0(x):
      return math_ops.square(x)

    @function.Defun(dtypes.int32)
    def f1(x):
      return f0(x)

    # Both functions are defined but not called yet, so no op is in use.
    self.assertEqual(len(_stripped_op_names()), 0)

    # Calling f1 on a constant must bring in exactly Const and Square.
    _ = f1(constant_op.constant(7))
    self.assertEqual(["Const", "Square"], _stripped_op_names())
def testStrippedOpListRecursiveFunctions(self):
  """Checks the stripped op list for mutually recursive library functions."""
  # The function module doesn't support recursive functions, so the cycle is
  # assembled by hand in a GraphDef: A calls B, B calls Const and A.
  graph_def = graph_pb2.GraphDef()
  func_a = graph_def.library.function.add()
  func_b = graph_def.library.function.add()
  func_a.signature.name = "A"
  func_b.signature.name = "B"
  func_a.node_def.add().op = "B"
  for callee in ("Const", "A"):
    func_b.node_def.add().op = callee

  # The graph body itself uses A.
  graph_def.node.add().op = "A"

  # Only Const is expected in the stripped op list.
  stripped = meta_graph.stripped_op_list_for_graph(graph_def)
  found_names = [op_def.name for op_def in stripped.op]
  self.assertEqual(["Const"], found_names)
def testStrippedOpListPartitionedCalls(self):
  """Checks the stripped op list when functions are used via the "f" attr."""
  graph_def = graph_pb2.GraphDef()
  func_a = graph_def.library.function.add()
  func_b = graph_def.library.function.add()
  func_a.signature.name = "A"
  func_b.signature.name = "B"

  # A's only node refers to B through a StatefulPartitionedCall.
  call_b = func_a.node_def.add()
  call_b.op = "StatefulPartitionedCall"
  call_b.attr["f"].func.name = "B"

  # B holds a Const node and a node whose op is A.
  for op_name in ("Const", "A"):
    func_b.node_def.add().op = op_name

  # The graph body reaches A through a PartitionedCall node.
  call_a = graph_def.node.add()
  call_a.op = "PartitionedCall"
  call_a.attr["f"].func.name = "A"

  stripped = meta_graph.stripped_op_list_for_graph(graph_def)
  found_names = [op_def.name for op_def in stripped.op]
  self.assertSameElements(
      ["Const", "PartitionedCall", "StatefulPartitionedCall"], found_names)
@test_util.run_deprecated_v1
def testDefaultAttrStripping(self):
  """Verifies that default attributes are stripped from a graph def."""

  def _add_complex_op(float_dtype):
    # Adds complex(real, imag) to the default graph; both inputs are
    # constants of `float_dtype`.
    real_part = constant_op.constant(1.0, dtype=float_dtype, name="real")
    imag_part = constant_op.constant(2.0, dtype=float_dtype, name="imag")
    math_ops.complex(real_part, imag_part, name="complex")

  def _export(strip):
    # Exports the default graph and returns the MetaGraphDef along with the
    # exported node def named "complex".
    exported, _ = meta_graph.export_scoped_meta_graph(
        graph_def=ops.get_default_graph().as_graph_def(),
        strip_default_attrs=strip)
    complex_node = test_util.get_node_def_from_graph("complex",
                                                     exported.graph_def)
    return exported, complex_node

  # The Complex op has 2 attributes with defaults:
  #   o "T"    : float32.
  #   o "Tout" : complex64.
  # With float32 inputs "T" is float32 and "Tout" is complex64, i.e. both
  # equal their defaults and must be stripped unless stripping is disabled.
  with self.cached_session():
    _add_complex_op(dtypes.float32)

    # Stripping enabled: neither attribute may remain on the node.
    exported, complex_node = _export(True)
    for attr_name in ("T", "Tout"):
      self.assertNotIn(attr_name, complex_node.attr)
    self.assertTrue(exported.meta_info_def.stripped_default_attrs)

    # Stripping disabled: both attributes must still be present.
    exported, complex_node = _export(False)
    for attr_name in ("T", "Tout"):
      self.assertIn(attr_name, complex_node.attr)
    self.assertFalse(exported.meta_info_def.stripped_default_attrs)

  # With float64 inputs "T" is float64 and "Tout" is complex128. These differ
  # from the defaults, so they must survive stripping.
  with self.session(graph=ops.Graph()):
    _add_complex_op(dtypes.float64)
    exported, complex_node = _export(True)
    self.assertEqual(complex_node.attr["T"].type, dtypes.float64)
    self.assertEqual(complex_node.attr["Tout"].type, dtypes.complex128)
    self.assertTrue(exported.meta_info_def.stripped_default_attrs)
@test_util.run_deprecated_v1
def testDefaultAttrStrippingNestedFunctions(self):
  """Verifies that default attributes are stripped from function node defs."""
  with self.cached_session():
    @function.Defun(dtypes.float32, dtypes.float32)
    def f0(i, j):
      return math_ops.complex(i, j, name="double_nested_complex")

    @function.Defun(dtypes.float32, dtypes.float32)
    def f1(i, j):
      return f0(i, j)

    _ = f1(constant_op.constant(1.0), constant_op.constant(2.0))
    exported, _ = meta_graph.export_scoped_meta_graph(
        graph_def=ops.get_default_graph().as_graph_def(),
        strip_default_attrs=True)

    def _first_library_node(prefix):
      # First node def, across all library functions, whose name starts with
      # `prefix`; None when there is no such node.
      for func_def in exported.graph_def.library.function:
        for candidate in func_def.node_def:
          if candidate.name.startswith(prefix):
            return candidate
      return None

    nested_node = _first_library_node("double_nested_complex")

    self.assertIsNotNone(nested_node)
    for attr_name in ("T", "Tout"):
      self.assertNotIn(attr_name, nested_node.attr)
    self.assertTrue(exported.meta_info_def.stripped_default_attrs)
def testDefaultAttrStrippingUnregisteredOps(self):
  """Verifies that nodes with un-registered ops are not stripped."""
  node_name = "node_with_unreg_op"

  # Hand-built GraphDef with a single node whose op is "unreg_op".
  source_graph_def = graph_pb2.GraphDef()
  unregistered = source_graph_def.node.add()
  unregistered.name = node_name
  unregistered.op = "unreg_op"
  unregistered.attr["attr_1"].i = 1

  # MetaInfoDef whose stripped op list has one (empty) op entry.
  info = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
  info.stripped_op_list.op.add()

  with self.cached_session():
    result = meta_graph.create_meta_graph_def(
        meta_info_def=info,
        graph_def=source_graph_def,
        strip_default_attrs=True)
    kept_node = test_util.get_node_def_from_graph(node_name,
                                                  result.graph_def)
    # The attribute keeps its value even though stripping was requested.
    self.assertEqual(kept_node.attr["attr_1"].i, 1)
    self.assertTrue(result.meta_info_def.stripped_default_attrs)
@test_util.run_deprecated_v1
def testVariableObjectsAreSharedAmongCollections(self):
  """One Variable object backs both collections, also after an import."""
  global_key = ops.GraphKeys.GLOBAL_VARIABLES
  trainable_key = ops.GraphKeys.TRAINABLE_VARIABLES

  with ops.Graph().as_default() as graph1:
    created = variables.Variable(3.0)
    from_global = graph1.get_collection(global_key)
    from_trainable = graph1.get_collection(trainable_key)
    self.assertEqual(len(from_global), 1)
    self.assertEqual(len(from_trainable), 1)
    # Both collections hold the very object that was created above.
    self.assertIs(from_global[0], from_trainable[0])
    self.assertIs(created, from_global[0])

  exported, _ = meta_graph.export_scoped_meta_graph(graph=graph1)
  del graph1  # To avoid accidental references in code involving graph2.

  with ops.Graph().as_default() as graph2:
    meta_graph.import_scoped_meta_graph(exported)
    from_global = graph2.get_collection(global_key)
    from_trainable = graph2.get_collection(trainable_key)
    self.assertEqual(len(from_global), 1)
    self.assertEqual(len(from_trainable), 1)
    # The imported collections must share a single Variable instance too.
    self.assertIs(from_global[0], from_trainable[0])
@test_util.run_deprecated_v1
def testMetricVariablesCollectionLoadsBytesList(self):
  """METRIC_VARIABLES entries copied from "variables" import as Variables."""
  metric_key = ops.GraphKeys.METRIC_VARIABLES

  with ops.Graph().as_default() as source_graph:
    original = variables.Variable(
        [1, 2, 3], shape=[3], dtype=dtypes.float64, name="v")

  exported, _ = meta_graph.export_scoped_meta_graph(graph=source_graph)

  # Copy the bytes list of the global variables collection into the metric
  # variables collection of the exported proto.
  collection_defs = exported.collection_def
  collection_defs[metric_key].CopyFrom(collection_defs["variables"])

  with ops.Graph().as_default() as target_graph:
    meta_graph.import_scoped_meta_graph(exported)

  restored_list = target_graph.get_collection(metric_key)
  self.assertEqual(len(restored_list), 1)
  restored = restored_list[0]
  self.assertIsInstance(restored, variables.Variable)
  # The restored variable must agree with the original on these properties.
  for attribute in ("name", "dtype", "shape"):
    self.assertEqual(
        getattr(original, attribute), getattr(restored, attribute))
class ScopedMetaGraphTest(test.TestCase):
def _testScopedExport(self, test_dir, exported_filenames):
  """Builds a three-layer graph and exports each layer's name scope to a file.

  Args:
    test_dir: Directory the exported files are written into.
    exported_filenames: Indexable holding three file names, used in order
      for the "hidden1", "hidden2" and "softmax_linear" scopes.

  Returns:
    A list with the three values returned first by
    `meta_graph.export_scoped_meta_graph`, in the same order as the scopes.
  """
  graph = ops.Graph()
  with graph.as_default():
    # Creates an inference graph.
    # Hidden 1
    colocate_constraint = constant_op.constant(1.2, name="constraint")
    images = constant_op.constant(
        1.2, dtypes.float32, shape=[100, 28], name="images")
    with ops.name_scope("hidden1"):
      # Only the weights are created under the colocation constraint.
      with graph.colocate_with(colocate_constraint.op):
        weights1 = variables.Variable(
            random_ops.truncated_normal(
                [28, 128], stddev=1.0 / math.sqrt(float(28))),
            name="weights")
      # The use of control_flow_ops.cond here is purely for adding test
      # coverage of the save and restore of control flow context (which
      # doesn't make any sense here from a machine learning perspective).
      # Typical biases are a simple Variable without the conditions.
      biases1 = variables.Variable(
          control_flow_ops.cond(
              math_ops.less(random.random(), 0.5),
              lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
          name="biases")
      hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)
    # Hidden 2
    with ops.name_scope("hidden2"):
      weights2 = variables.Variable(
          random_ops.truncated_normal(
              [128, 32], stddev=1.0 / math.sqrt(float(128))),
          name="weights")
      # The use of control_flow_ops.while_loop here is purely for adding test
      # coverage of the save and restore of control flow context (which
      # doesn't make any sense here from a machine learning perspective).
      # Typical biases are a simple Variable without the conditions.
      def loop_cond(it, _):
        return it < 2
      def loop_body(it, biases2):
        biases2 += constant_op.constant(0.1, shape=[32])
        return it + 1, biases2
      _, biases2 = control_flow_ops.while_loop(
          loop_cond,
          loop_body, [
              constant_op.constant(0), variables.Variable(
                  array_ops.zeros([32]), name="biases")
          ])
      hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
    # Linear
    with ops.name_scope("softmax_linear"):
      weights3 = variables.Variable(
          random_ops.truncated_normal(
              [32, 10], stddev=1.0 / math.sqrt(float(32))),
          name="weights")
      biases3 = variables.Variable(array_ops.zeros([10]), name="biases")
      logits = math_ops.matmul(hidden2, weights3) + biases3
      ops.add_to_collection("logits", logits)
    # Exports each sub-graph.
    # Exports the first one with unbound_inputs_col_name set to default.
    orig_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
        filename=os.path.join(test_dir, exported_filenames[0]),
        graph=ops.get_default_graph(),
        export_scope="hidden1")
    # The returned keys carry no scope prefix; the variables keep theirs.
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
    var_names = [v.name for _, v in var_list.items()]
    self.assertEqual(["hidden1/biases:0", "hidden1/weights:0"],
                     sorted(var_names))
    # Exports the rest with no unbound_inputs_col_name.
    orig_meta_graph2, _ = meta_graph.export_scoped_meta_graph(
        filename=os.path.join(test_dir, exported_filenames[1]),
        graph=ops.get_default_graph(),
        export_scope="hidden2",
        unbound_inputs_col_name=None)
    orig_meta_graph3, _ = meta_graph.export_scoped_meta_graph(
        filename=os.path.join(test_dir, exported_filenames[2]),
        graph=ops.get_default_graph(),
        export_scope="softmax_linear",
        unbound_inputs_col_name=None)
  return [orig_meta_graph1, orig_meta_graph2, orig_meta_graph3]
def _testScopedImport(self, test_dir, exported_filenames):
  """Imports three exported scopes into new scopes and re-exports them.

  Args:
    test_dir: Directory holding the previously exported files.
    exported_filenames: Indexable holding three file names; entries 0, 1
      and 2 are imported into "new_hidden1", "new_hidden2" and
      "new_softmax_linear" respectively.

  Returns:
    A list with the three values returned first by
    `meta_graph.export_scoped_meta_graph` for the new scopes, in that order.
  """
  graph = ops.Graph()
  # Create all the missing inputs.
  with graph.as_default():
    new_image = constant_op.constant(
        1.2, dtypes.float32, shape=[100, 28], name="images")
  # Importing without any input_map must be rejected.
  with self.assertRaisesRegex(ValueError, "Graph contains unbound inputs"):
    meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filenames[0]),
        graph=graph,
        import_scope="new_hidden1")
  # An input_map keyed by "image:0" must be rejected with the same error.
  with self.assertRaisesRegex(ValueError, "Graph contains unbound inputs"):
    meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filenames[0]),
        graph=graph,
        input_map={"image:0": new_image},
        import_scope="new_hidden1")
  # Verifies we can import the original "hidden1" into "new_hidden1".
  var_list = meta_graph.import_scoped_meta_graph(
      os.path.join(test_dir, exported_filenames[0]),
      graph=graph,
      input_map={"$unbound_inputs_images": new_image},
      import_scope="new_hidden1")
  self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
  new_var_names = [v.name for _, v in var_list.items()]
  self.assertEqual(["new_hidden1/biases:0", "new_hidden1/weights:0"],
                   sorted(new_var_names))
  # Verifies we can import the original "hidden2" into "new_hidden2".
  # The imported Relu output is wrapped in an identity op that is then bound
  # to the unbound input of the next scope.
  hidden1 = array_ops.identity(
      graph.as_graph_element("new_hidden1/Relu:0"), name="hidden1/Relu")
  var_list = meta_graph.import_scoped_meta_graph(
      os.path.join(test_dir, exported_filenames[1]),
      graph=graph,
      input_map={"$unbound_inputs_hidden1/Relu": hidden1},
      import_scope="new_hidden2",
      unbound_inputs_col_name=None)
  self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
  new_var_names = [v.name for _, v in var_list.items()]
  self.assertEqual(["new_hidden2/biases:0", "new_hidden2/weights:0"],
                   sorted(new_var_names))
  # Verifies we can import the original "softmax_linear" into
  # "new_softmax_linear".
  hidden2 = array_ops.identity(
      graph.as_graph_element("new_hidden2/Relu:0"), name="hidden2/Relu")
  var_list = meta_graph.import_scoped_meta_graph(
      os.path.join(test_dir, exported_filenames[2]),
      graph=graph,
      input_map={"$unbound_inputs_hidden2/Relu": hidden2},
      import_scope="new_softmax_linear",
      unbound_inputs_col_name=None)
  self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
  new_var_names = [v.name for _, v in var_list.items()]
  self.assertEqual(
      ["new_softmax_linear/biases:0", "new_softmax_linear/weights:0"],
      sorted(new_var_names))
  # Exports the scoped meta graphs again.
  new_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
      graph=graph, export_scope="new_hidden1")
  self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
  new_meta_graph2, var_list = meta_graph.export_scoped_meta_graph(
      graph=graph, export_scope="new_hidden2", unbound_inputs_col_name=None)
  self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
  new_meta_graph3, var_list = meta_graph.export_scoped_meta_graph(
      graph=graph,
      export_scope="new_softmax_linear",
      unbound_inputs_col_name=None)
  self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
  return [new_meta_graph1, new_meta_graph2, new_meta_graph3]
# Verifies that we can export the subgraph under each layer and import
# them into new layers in a new graph.
@test_util.run_deprecated_v1
def testScopedExportAndImport(self):
  """Exports each layer scope, re-imports it and compares the results."""
  scratch_dir = _TestDir("scoped_export_import")
  layer_files = [
      "exported_hidden1.pbtxt", "exported_hidden2.pbtxt",
      "exported_softmax_linear.pbtxt"
  ]
  exported_defs = self._testScopedExport(scratch_dir, layer_files)
  reimported_defs = self._testScopedImport(scratch_dir, layer_files)
  for exported, reimported in zip(exported_defs, reimported_defs):
    # The unbound input strings are slightly different with the C API enabled
    # ("images" vs "images:0") due to the original import_graph_def code
    # vs. ImportGraphDef in C++, so that collection is dropped on both sides
    # before comparing.
    # TODO(skyewm): update the pbtxts once _USE_C_API is removed.
    for meta_def in (exported, reimported):
      del meta_def.collection_def["unbound_inputs"]
    test_util.assert_meta_graph_protos_equal(self, exported, reimported)
def testWhileLoopGradients(self):
  """Checks while-loop gradients after a scoped export/import round trip."""

  def _keep_going(i, x):
    return i < 5

  def _accumulate(i, x):
    return (i + 1, x + math_ops.cast(i, dtypes.float32))

  def _to_import_scope(tensor_name):
    # Rewrites a tensor name from the "export" scope to the "import" scope.
    return "import/" + tensor_name.replace("export/", "")

  # Build a simple while loop inside the "export" name scope.
  with ops.Graph().as_default():
    with ops.name_scope("export"):
      loop_var = variables.Variable(0.)
      var_name = loop_var.name
      _, loop_out = control_flow_ops.while_loop(
          _keep_going, _accumulate, [0, loop_var])
      output_name = loop_out.name

    # Export the loop before any gradient ops are added to the graph.
    exported, _ = meta_graph.export_scoped_meta_graph(export_scope="export")

    # Reference gradient, compared with the imported graph's gradient below.
    init_op = variables.global_variables_initializer()
    grad = gradients_impl.gradients([loop_out], [loop_var])
    with session.Session():
      self.evaluate(init_op)
      expected_grad_value = self.evaluate(grad)

  # Restore the export into a new graph under the "import" scope.
  with ops.Graph().as_default():
    meta_graph.import_scoped_meta_graph(exported, import_scope="import")

    # Exporting the imported scope must reproduce the original MetaGraphDef.
    reexported, _ = meta_graph.export_scoped_meta_graph(
        export_scope="import")
    test_util.assert_meta_graph_protos_equal(self, exported, reexported)

    # Gradients must still be buildable and give the same value.
    default_graph = ops.get_default_graph()
    loop_var = default_graph.get_tensor_by_name(_to_import_scope(var_name))
    loop_out = default_graph.get_tensor_by_name(
        _to_import_scope(output_name))
    grad = gradients_impl.gradients([loop_out], [loop_var])
    init_op = variables.global_variables_initializer()
    with session.Session():
      self.evaluate(init_op)
      actual_grad_value = self.evaluate(grad)
      self.assertEqual(expected_grad_value, actual_grad_value)
@test_util.run_v1_only("b/120545219")
def testImportWhileLoopInWhileLoop(self):
  """Imports a MetaGraphDef holding a while loop from inside a loop body."""
  # Build and export a graph that contains a simple while loop.
  with ops.Graph().as_default():
    doubled_var = variables.Variable(0.0)
    _, inner_out = control_flow_ops.while_loop(
        lambda i, x: i < 5, lambda i, x: (i + 1, x * 2.0), [0, doubled_var])
    output_name = inner_out.name
    exported, _ = meta_graph.export_scoped_meta_graph()

  # In a new graph, import that export from within another loop's body.
  with ops.Graph().as_default():

    def _import_in_body(i, _):
      meta_graph.import_scoped_meta_graph(exported)
      next_i = i + 1
      imported_out = ops.get_default_graph().get_tensor_by_name(output_name)
      return next_i, imported_out

    _, outer_out = control_flow_ops.while_loop(
        lambda i, x: i < 2, _import_in_body, [0, 0.0], name="")
    with session.Session():
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(outer_out)
@test_util.run_deprecated_v1
def testScopedImportUnderNameScope(self):
  """import_scope "bar" under name scope "foo" yields "foo/bar/..." names."""
  source_graph = ops.Graph()
  with source_graph.as_default():
    variables.Variable(initial_value=1.0, trainable=True, name="myvar")
  exported, _ = meta_graph.export_scoped_meta_graph(graph=source_graph)

  with ops.Graph().as_default(), ops.name_scope("foo"):
    imported = meta_graph.import_scoped_meta_graph(
        exported, import_scope="bar")
    self.assertEqual(len(imported), 1)
    (only_variable,) = imported.values()
    self.assertEqual(only_variable.name, "foo/bar/myvar:0")
@test_util.run_deprecated_v1
def testScopedImportUnderNameScopeNoVarScope(self):
  """Without import_scope, name scope "foo" alone prefixes imported names."""
  source_graph = ops.Graph()
  with source_graph.as_default():
    variables.Variable(initial_value=1.0, trainable=True, name="myvar")
  exported, _ = meta_graph.export_scoped_meta_graph(graph=source_graph)

  with ops.Graph().as_default(), ops.name_scope("foo"):
    imported = meta_graph.import_scoped_meta_graph(exported)
    self.assertEqual(len(imported), 1)
    (only_variable,) = imported.values()
    self.assertEqual(only_variable.name, "foo/myvar:0")
def testImportsUsingSameScopeName(self):
  """Importing twice with import_scope "s" yields scopes "s" and "s_1"."""
  with ops.Graph().as_default():
    variables.Variable(0, name="v")
    exported, _ = meta_graph.export_scoped_meta_graph()

  with ops.Graph().as_default():
    # The same import scope is requested on every pass; only the scope
    # expected in the resulting variable name changes.
    for expected_scope in ("s", "s_1"):
      imported = meta_graph.import_scoped_meta_graph(
          exported, import_scope="s")
      self.assertEqual(len(imported), 1)
      self.assertEqual(list(imported.keys())[0], "v:0")
      self.assertEqual(
          list(imported.values())[0].name, expected_scope + "/v:0")
@test_util.run_deprecated_v1
def testScopedImportWithSelectedCollections(self):
  """restore_collections_predicate decides which collections are imported."""
  export_path = os.path.join(
      _TestDir("selected_collections_import"), "meta_graph.pb")
  global_key = ops.GraphKeys.GLOBAL_VARIABLES
  trainable_key = ops.GraphKeys.TRAINABLE_VARIABLES

  # Add a variable to populate two collections. The functionality tested is
  # not specific to variables, but using variables in the test is convenient.
  source_graph = ops.Graph()
  with source_graph.as_default():
    variables.Variable(initial_value=1.0, trainable=True)
  populated = [
      source_graph.get_collection(key) for key in (global_key, trainable_key)
  ]
  self.assertTrue(all(populated))
  meta_graph.export_scoped_meta_graph(
      filename=export_path, graph=source_graph)

  def _check_import(restored_keys, skipped_keys):
    # Imports into a new graph, restoring only `restored_keys`, then checks
    # which collections ended up non-empty.
    assert set(restored_keys).isdisjoint(skipped_keys)
    target_graph = ops.Graph()
    scope = "some_scope_name"

    def _should_restore(collection_key):
      if collection_key in skipped_keys:
        return False
      return collection_key in restored_keys

    meta_graph.import_scoped_meta_graph(
        export_path,
        graph=target_graph,
        import_scope=scope,
        restore_collections_predicate=_should_restore)
    restored_values = [
        target_graph.get_collection(name=key, scope=scope)
        for key in restored_keys
    ]
    self.assertTrue(all(restored_values))
    skipped_values = [
        target_graph.get_collection(name=key, scope=scope)
        for key in skipped_keys
    ]
    self.assertFalse(any(skipped_values))

  # Each case is (collections to restore, collections to leave out).
  cases = [
      ([global_key, trainable_key], []),
      ([global_key], [trainable_key]),
      ([trainable_key], [global_key]),
      ([], [global_key, trainable_key]),
  ]
  for restored_keys, skipped_keys in cases:
    _check_import(restored_keys, skipped_keys)
def _testScopedExportWithQueue(self, test_dir, exported_filename):
  """Builds a FIFOQueue graph under "queue1" and exports that scope.

  Args:
    test_dir: Directory the export file is written into.
    exported_filename: Name of the export file inside `test_dir`.

  Returns:
    The first value returned by `meta_graph.export_scoped_meta_graph`.
  """
  export_path = os.path.join(test_dir, exported_filename)
  with ops.Graph().as_default():
    with ops.name_scope("queue1"):
      fifo = data_flow_ops.FIFOQueue(10, dtypes.float32)
      enqueue_op = fifo.enqueue(9876, name="enqueue")
      close_op = fifo.close(name="close")
      runner = queue_runner_impl.QueueRunner(fifo, [enqueue_op], close_op)
      queue_runner_impl.add_queue_runner(runner)
      fifo.dequeue(name="dequeue")

    exported, _ = meta_graph.export_scoped_meta_graph(
        filename=export_path,
        graph=ops.get_default_graph(),
        export_scope="queue1")
  return exported
def _testScopedImportWithQueue(self, test_dir, exported_filename,
                               new_exported_filename):
  """Imports the queue export into "new_queue1" and exports it again.

  Args:
    test_dir: Directory holding the export files.
    exported_filename: Name of the existing export file to import.
    new_exported_filename: Name of the file the re-export is written to.

  Returns:
    The first value returned by `meta_graph.export_scoped_meta_graph`.
  """
  source_path = os.path.join(test_dir, exported_filename)
  target_path = os.path.join(test_dir, new_exported_filename)

  imported_graph = ops.Graph()
  meta_graph.import_scoped_meta_graph(
      source_path, graph=imported_graph, import_scope="new_queue1")
  # Look up the imported dequeue tensor and close op; the results are unused.
  for element_name in ("new_queue1/dequeue:0", "new_queue1/close"):
    imported_graph.as_graph_element(element_name)

  with imported_graph.as_default():
    reexported, _ = meta_graph.export_scoped_meta_graph(
        filename=target_path,
        graph=imported_graph,
        export_scope="new_queue1")
  return reexported
# Verifies that we can export the subgraph containing a FIFOQueue under
# "queue1" and import it into "new_queue1" in a new graph.
@test_util.run_deprecated_v1
def testScopedWithQueue(self):
  """Round-trips the "queue1" scope and compares both MetaGraphDefs."""
  scratch_dir = _TestDir("scoped_with_queue")
  first_export = "exported_queue1.pbtxt"
  second_export = "exported_new_queue1.pbtxt"

  exported = self._testScopedExportWithQueue(scratch_dir, first_export)
  reexported = self._testScopedImportWithQueue(scratch_dir, first_export,
                                               second_export)
  test_util.assert_meta_graph_protos_equal(self, exported, reexported)
def testExportDebugInfo(self):
  """Checks that graph debug info has a trace for every op of a graph."""
  graph1 = ops.Graph()
  with graph1.as_default():
    with ops.name_scope("hidden1/hidden2/hidden3"):
      images = constant_op.constant(
          1.0, dtypes.float32, shape=[3, 2], name="images")
      weights1 = variables.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
                                    name="weights")
      biases1 = resource_variable_ops.ResourceVariable(
          [0.1] * 3, name="biases")
      nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")

  # Every operation is paired with an empty function name.
  func_named_operations = [("", op) for op in graph1.get_operations()]
  debug_info_def = error_interpolation.create_graph_debug_info_def(
      func_named_operations)

  # The number of unique file names across all the stack traces should be
  # greater than or equal to 1. assertGreaterEqual reports the actual count
  # on failure, unlike assertTrue on a comparison.
  self.assertGreaterEqual(len(debug_info_def.files), 1)
  # All the nodes from the exported graphdef are included.
  self.assertEqual(len(debug_info_def.traces), len(graph1.get_operations()))
# Verifies that we can export a subgraph in a nested name scope containing a
# "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new
# graph.
def doTestExportNestedNames(self, use_resource=False):
  """Exports a nested name scope and imports it under a new nested scope.

  The graph is built under "hidden1/hidden2/hidden3", exported with scope
  "hidden1/hidden2" and imported under "new_hidden1/new_hidden2".

  Args:
    use_resource: If True, "weights" is created first and "biases" is a
      ResourceVariable; otherwise "biases" is created first and both are
      `variables.Variable` objects.
  """
  graph1 = ops.Graph()
  with graph1.as_default():
    with ops.name_scope("hidden1/hidden2/hidden3"):
      images = constant_op.constant(
          1.0, dtypes.float32, shape=[3, 2], name="images")
      if use_resource:
        weights1 = variables.Variable(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
        biases1 = resource_variable_ops.ResourceVariable(
            [0.1] * 3, name="biases")
      else:
        biases1 = variables.Variable([0.1] * 3, name="biases")
        weights1 = variables.Variable(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
      nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")

  orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
      export_scope="hidden1/hidden2", graph=graph1)
  # Keys are relative to the export scope; the variables keep full names.
  self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
                   sorted(var_list.keys()))
  var_names = [v.name for v in var_list.values()]
  self.assertEqual([
      "hidden1/hidden2/hidden3/biases:0", "hidden1/hidden2/hidden3/weights:0"
  ], sorted(var_names))
  # Every exported node name must be relative to the export scope.
  for node in orig_meta_graph.graph_def.node:
    self.assertTrue(node.name.startswith("hidden3"))

  graph2 = ops.Graph()
  new_var_list = meta_graph.import_scoped_meta_graph(
      orig_meta_graph, import_scope="new_hidden1/new_hidden2", graph=graph2)
  self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
                   sorted(new_var_list.keys()))
  new_var_names = [v.name for v in new_var_list.values()]
  self.assertEqual([
      "new_hidden1/new_hidden2/hidden3/biases:0",
      "new_hidden1/new_hidden2/hidden3/weights:0"
  ], sorted(new_var_names))
@test_util.run_deprecated_v1
def testExportNestedNames(self):
  """Runs the nested-name export/import check with use_resource=False."""
  self.doTestExportNestedNames(use_resource=False)
@test_util.run_deprecated_v1
def testExportNestedNamesResource(self):
  """Runs the nested-name export/import check with use_resource=True."""
  self.doTestExportNestedNames(use_resource=True)
@test_util.run_deprecated_v1
def testPotentialCycle(self):
  """A scope with an unbound input imports only when that input is mapped."""
  source_graph = ops.Graph()
  with source_graph.as_default():
    lhs = constant_op.constant(1.0, shape=[2, 2])
    rhs = constant_op.constant(2.0, shape=[2, 2])
    outer_product = math_ops.matmul(lhs, rhs)
    with ops.name_scope("hidden1"):
      # The scope consumes `outer_product`, which is created outside of it.
      activated = nn_ops.relu(outer_product)
      scale = constant_op.constant(3.0, shape=[2, 2])
      math_ops.matmul(activated, scale)

  exported, _ = meta_graph.export_scoped_meta_graph(
      export_scope="hidden1", graph=source_graph)

  with ops.Graph().as_default():
    # Without an input_map the import must be rejected.
    with self.assertRaisesRegex(ValueError, "Graph contains unbound inputs"):
      meta_graph.import_scoped_meta_graph(
          exported, import_scope="new_hidden1")

    # Mapping the unbound input to a fresh constant makes the import work.
    replacement = constant_op.constant(4.0, shape=[2, 2])
    meta_graph.import_scoped_meta_graph(
        exported,
        import_scope="new_hidden1",
        input_map={"$unbound_inputs_MatMul": replacement})
@test_util.run_deprecated_v1
def testClearDevices(self):
  """Checks `clear_devices` on both the export and the import path.

  A graph with three explicitly placed nodes is exported and re-imported in
  three different ways; after each round trip every one of those nodes in the
  imported graph is expected to report an empty device string.
  """
  node_names = ("a", "b", "matmul")
  no_devices = ("", "", "")

  def check_devices(graph, expected_devices):
    """Asserts the device string of each named node, in `node_names` order."""
    for node_name, expected_device in zip(node_names, expected_devices):
      self.assertEqual(expected_device,
                       str(graph.as_graph_element(node_name).device))

  def import_into_fresh_graph(meta_graph_def, clear_devices):
    """Imports `meta_graph_def` into a new graph and returns that graph."""
    fresh_graph = ops.Graph()
    with fresh_graph.as_default():
      meta_graph.import_scoped_meta_graph(
          meta_graph_def, clear_devices=clear_devices)
    return fresh_graph

  graph1 = ops.Graph()
  with graph1.as_default():
    with ops.device("/device:CPU:0"):
      var_a = variables.Variable(
          constant_op.constant(1.0, shape=[2, 2]), name="a")
    with ops.device("/job:ps/replica:0/task:0/device:GPU:0"):
      var_b = variables.Variable(
          constant_op.constant(2.0, shape=[2, 2]), name="b")
    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      math_ops.matmul(var_a, var_b, name="matmul")

  # The legacy "cpu:0" spec is expected to read back as "device:CPU:0".
  check_devices(graph1, (
      "/device:CPU:0",
      "/job:ps/replica:0/task:0/device:GPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:0",
  ))

  # Verifies that devices are cleared on export.
  exported, _ = meta_graph.export_scoped_meta_graph(
      graph=graph1, clear_devices=True)
  graph2 = import_into_fresh_graph(exported, clear_devices=False)
  check_devices(graph2, no_devices)

  # Verifies that devices are cleared on export when passing in graph_def.
  exported, _ = meta_graph.export_scoped_meta_graph(
      graph_def=graph1.as_graph_def(), clear_devices=True)
  graph2 = import_into_fresh_graph(exported, clear_devices=False)
  check_devices(graph2, no_devices)

  # Verifies that devices are cleared on import.
  exported, _ = meta_graph.export_scoped_meta_graph(
      graph=graph1, clear_devices=False)
  graph2 = import_into_fresh_graph(exported, clear_devices=True)
  check_devices(graph2, no_devices)
class MetaGraphWithVariableScopeTest(test.TestCase):
  """Tests importing meta graphs that carry a LOCAL_VARIABLES collection."""

  @test_util.run_deprecated_v1
  def testMetricsCollection(self):
    """Round-trips a graph containing `metrics.mean` through a meta graph file.

    Covers three steps: exporting a graph that uses a mean metric, importing
    that export and running the local-variables initializer, and importing a
    checked-in older meta graph for which building the initializer is
    expected to raise AttributeError.
    """

    def _enqueue_vector(sess, queue, values, shape=None):
      """Enqueues `values` as one element; shape defaults to a single row."""
      element = constant_op.constant(
          values, dtype=queue.dtypes[0], shape=shape or (1, len(values)))
      sess.run(queue.enqueue(element))

    meta_graph_filename = os.path.join(
        _TestDir("metrics_export"), "meta_graph.pb")

    export_graph = ops.Graph()
    with self.session(graph=export_graph) as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes.float32, shapes=(1, 2))
      for vector in ([0, 1], [-4.2, 9.1], [6.5, 0], [-3.2, 4.0]):
        _enqueue_vector(sess, values_queue, vector)
      dequeued = values_queue.dequeue()

      _, update_op = metrics.mean(dequeued)

      self.evaluate(variables.local_variables_initializer())
      self.evaluate(update_op)

    meta_graph.export_scoped_meta_graph(
        filename=meta_graph_filename, graph=export_graph)

    # Verifies that importing a meta_graph with LOCAL_VARIABLES collection
    # works correctly.
    reimport_graph = ops.Graph()
    with self.session(graph=reimport_graph):
      meta_graph.import_scoped_meta_graph(meta_graph_filename)
      self.evaluate(variables.local_variables_initializer())

    # Verifies that importing an old meta_graph where "local_variables"
    # collection is of node_list type works, but cannot build initializer
    # with the collection.
    legacy_graph = ops.Graph()
    with self.session(graph=legacy_graph):
      legacy_path = test.test_src_dir_path(
          "python/framework/testdata/metrics_export_meta_graph.pb")
      meta_graph.import_scoped_meta_graph(legacy_path)
      local_vars = ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
      self.assertEqual(len(local_vars), 2)
      with self.assertRaisesRegex(
          AttributeError, "'Tensor' object has no attribute 'initializer'"):
        variables.local_variables_initializer()
class ExportImportAcrossScopesTest(test.TestCase):
  """Tests exporting from one name scope and importing under another."""

  @test_util.run_deprecated_v1
  def testPartitionedVariables(self):
    """Runs the cross-scope check for a partitioned variable graph.

    The check is run once with `use_resource=False` and once with
    `use_resource=True`.
    """

    def build_partitioned_variables(use_resource):
      partitioner = partitioned_variables.fixed_size_partitioner(3, axis=0)
      initial_weights = random_ops.truncated_normal([100, 10])
      variable_scope.get_variable(
          name="weights",
          partitioner=partitioner,
          initializer=initial_weights,
          use_resource=use_resource)
      # The next variable illustrates the necessity of restoring collections
      # in a deterministic fashion when using ResourceVariables.
      variable_scope.get_variable(
          name="another",
          shape=[],
          collections=["a", "b", "z", "f", "e", "d", "g"],
          use_resource=use_resource)

    for use_resource in (False, True):
      self._testExportImportAcrossScopes(
          build_partitioned_variables, use_resource=use_resource)

  def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests export and importing a graph across scopes.

    The graph is built under "dropA/dropB/keepA", exported with export scope
    "dropA/dropB" and imported under "importA". The outcome is compared with
    a graph built directly under "importA/keepA".

    Args:
      graph_fn: A closure that creates a graph on the current scope.
      use_resource: A bool indicating whether or not to use ResourceVariables.
    """
    with ops.Graph().as_default() as original_graph:
      with variable_scope.variable_scope("dropA/dropB/keepA"):
        graph_fn(use_resource=use_resource)
    exported_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
        graph=original_graph, export_scope="dropA/dropB")

    with ops.Graph().as_default() as imported_graph:
      meta_graph.import_scoped_meta_graph(
          exported_meta_graph_def, import_scope="importA")

    with ops.Graph().as_default() as expected_graph:
      with variable_scope.variable_scope("importA/keepA"):
        graph_fn(use_resource=use_resource)

    result, _ = meta_graph.export_scoped_meta_graph(graph=imported_graph)
    expected, _ = meta_graph.export_scoped_meta_graph(graph=expected_graph)

    if use_resource:
      # Clear all shared_name attributes before comparing, since they are
      # orthogonal to scopes and are not updated on export/import.
      attr_key = "shared_name"
      for meta_graph_def in (result, expected):
        for node in meta_graph_def.graph_def.node:
          attr_value = node.attr.get(attr_key, None)
          if attr_value and attr_value.HasField("s") and attr_value.s:
            node.attr[attr_key].s = b""

    test_util.assert_meta_graph_protos_equal(self, expected, result)
# Script entry point: invoke `test.main()` when this file is run directly.
if __name__ == "__main__":
  test.main()