Support limited forward compatibility when importing a MetaGraphDef.
Change: 121880284
This commit is contained in:
parent
bf7e5fe193
commit
6e02bf0299
@ -23,8 +23,8 @@ import contextlib
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.util import compat
|
||||
@ -142,8 +142,15 @@ def _MaybeDevice(device):
|
||||
yield
|
||||
|
||||
|
||||
def _FindAttrInOpDef(attr_name, op_def):
|
||||
for attr_def in op_def.attr:
|
||||
if attr_name == attr_def.name:
|
||||
return attr_def
|
||||
return None
|
||||
|
||||
|
||||
def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
name=None, op_dict=None):
|
||||
name=None, op_dict=None, producer_op_list=None):
|
||||
"""Imports the TensorFlow graph in `graph_def` into the Python `Graph`.
|
||||
|
||||
This function provides a way to import a serialized TensorFlow
|
||||
@ -167,6 +174,12 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
|
||||
Must contain an `OpDef` proto for each op type named in `graph_def`.
|
||||
If omitted, uses the `OpDef` protos registered in the global registry.
|
||||
producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
|
||||
list of `OpDef`s used by the producer of the graph. If provided, attrs
|
||||
for ops in `graph_def` that are not in `op_dict` that have their default
|
||||
value according to `producer_op_list` will be removed. This will allow
|
||||
some more `GraphDef`s produced by later binaries to be accepted by
|
||||
earlier binaries.
|
||||
|
||||
Returns:
|
||||
A list of `Operation` and/or `Tensor` objects from the imported graph,
|
||||
@ -213,6 +226,11 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
if op_dict is None:
|
||||
op_dict = op_def_registry.get_registered_ops()
|
||||
|
||||
if producer_op_list is None:
|
||||
producer_op_dict = None
|
||||
else:
|
||||
producer_op_dict = {op.name: op for op in producer_op_list.op}
|
||||
|
||||
with ops.op_scope(input_map.values(), name, 'import'):
|
||||
g = ops.get_default_graph()
|
||||
g.graph_def_versions.CopyFrom(graph_def.versions)
|
||||
@ -233,6 +251,21 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
value = node.attr[key]
|
||||
if value is None or value.WhichOneof('value') is None:
|
||||
node.attr[key].CopyFrom(attr_def.default_value)
|
||||
if producer_op_dict:
|
||||
# Remove any default attr values that aren't in op_def.
|
||||
if node.op in producer_op_dict:
|
||||
producer_op_def = producer_op_dict[node.op]
|
||||
# We make a copy of node.attr to iterate through since we
|
||||
# may modify node.attr inside the loop.
|
||||
for key in list(node.attr):
|
||||
if _FindAttrInOpDef(key, op_def) is None:
|
||||
# No attr_def in consumer, look in producer.
|
||||
attr_def = _FindAttrInOpDef(key, producer_op_def)
|
||||
if (attr_def and attr_def.HasField('default_value') and
|
||||
node.attr[key] == attr_def.default_value):
|
||||
# Unknown attr had default value in producer, delete it
|
||||
# so it can be understood by consumer.
|
||||
del node.attr[key]
|
||||
|
||||
output_types = _OutputTypes(node, op_dict)
|
||||
name_to_op[node.name] = g.create_op(
|
||||
@ -326,8 +359,8 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
|
||||
_InvalidNodeMessage(
|
||||
node,
|
||||
'Input types mismatch (expected %r but got %r)'
|
||||
% (", ".join(dtypes.as_dtype(x).name for x in input_types),
|
||||
", ".join(x.name for x in op._input_dtypes))))
|
||||
% (', '.join(dtypes.as_dtype(x).name for x in input_types),
|
||||
', '.join(x.name for x in op._input_dtypes))))
|
||||
# pylint: enable=protected_access
|
||||
|
||||
# Execute shape inference for this op.
|
||||
|
||||
@ -106,12 +106,16 @@ text_format.Merge("""
|
||||
output_arg { name: 'a' type: DT_INT32 }
|
||||
attr { name: 'default_float' type: 'float' default_value { f: 123.0 } }
|
||||
}
|
||||
op {
|
||||
name: 'OpWithFutureDefaultAttr'
|
||||
}
|
||||
""", _op_list)
|
||||
op_def_registry.register_op_list(_op_list)
|
||||
# NOTE(mrry): Dummy shape registrations for ops used in the tests.
|
||||
for op_def in _op_list.op:
|
||||
tf.RegisterShape(op_def.name)(None)
|
||||
|
||||
|
||||
class ImportGraphDefTest(tf.test.TestCase):
|
||||
|
||||
def _MakeGraphDef(self, text, producer=tf.GRAPH_DEF_VERSION,
|
||||
@ -139,8 +143,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
attr { key: 'T' value { type: DT_FLOAT } }
|
||||
input: 'A:1' input: 'B:1' }
|
||||
"""),
|
||||
return_elements=['A', 'B', 'C', 'D'],
|
||||
name='import')
|
||||
return_elements=["A", "B", "C", "D"],
|
||||
name="import")
|
||||
|
||||
# Assert that the import process creates distinct tensors.
|
||||
self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
|
||||
@ -157,20 +161,20 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
self.assertEqual(d.inputs[1], b.outputs[1])
|
||||
|
||||
# Check the types of the returned ops and tensors.
|
||||
self.assertEqual(a.type, 'Oif')
|
||||
self.assertEqual(b.type, 'Otl')
|
||||
self.assertEqual(c.type, 'In')
|
||||
self.assertEqual(d.type, 'In')
|
||||
self.assertEqual(a.type, "Oif")
|
||||
self.assertEqual(b.type, "Otl")
|
||||
self.assertEqual(c.type, "In")
|
||||
self.assertEqual(d.type, "In")
|
||||
self.assertEqual(a.outputs[0].dtype, tf.int32)
|
||||
self.assertEqual(a.outputs[1].dtype, tf.float32)
|
||||
self.assertEqual(b.outputs[0].dtype, tf.int32)
|
||||
self.assertEqual(b.outputs[1].dtype, tf.float32)
|
||||
|
||||
# Check the names of the returned ops.
|
||||
self.assertEqual(a.name, 'import/A')
|
||||
self.assertEqual(b.name, 'import/B')
|
||||
self.assertEqual(c.name, 'import/C')
|
||||
self.assertEqual(d.name, 'import/D')
|
||||
self.assertEqual(a.name, "import/A")
|
||||
self.assertEqual(b.name, "import/B")
|
||||
self.assertEqual(c.name, "import/C")
|
||||
self.assertEqual(d.name, "import/D")
|
||||
|
||||
# Check that the op_def is still available.
|
||||
self.assertNotEqual(None, a.op_def)
|
||||
@ -193,8 +197,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
attr { key: 'T' value { type: DT_INT32 } }
|
||||
input: 'A:1' input: 'B:1' }
|
||||
"""),
|
||||
input_map={'A:0': feed_a_0, 'B:1': feed_b_1},
|
||||
return_elements=['A', 'B', 'C', 'D'])
|
||||
input_map={"A:0": feed_a_0, "B:1": feed_b_1},
|
||||
return_elements=["A", "B", "C", "D"])
|
||||
|
||||
self.assertEqual(c.inputs[0], feed_a_0)
|
||||
self.assertEqual(c.inputs[1], b.outputs[0])
|
||||
@ -219,8 +223,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
attr { key: 'T' value { type: DT_INT32 } }
|
||||
input: 'A:1' input: 'B:1' }
|
||||
"""),
|
||||
input_map={b'A:0': feed_a_0, b'B:1': feed_b_1},
|
||||
return_elements=[b'A', b'B', b'C', b'D'])
|
||||
input_map={b"A:0": feed_a_0, b"B:1": feed_b_1},
|
||||
return_elements=[b"A", b"B", b"C", b"D"])
|
||||
|
||||
self.assertEqual(c.inputs[0], feed_a_0)
|
||||
self.assertEqual(c.inputs[1], b.outputs[0])
|
||||
@ -245,8 +249,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
attr { key: 'T' value { type: DT_INT32 } }
|
||||
input: 'A:1' input: 'B:1' }
|
||||
"""),
|
||||
input_map={u'A:0': feed_a_0, u'B:1': feed_b_1},
|
||||
return_elements=[u'A', u'B', u'C', u'D'])
|
||||
input_map={u"A:0": feed_a_0, u"B:1": feed_b_1},
|
||||
return_elements=[u"A", u"B", u"C", u"D"])
|
||||
|
||||
self.assertEqual(c.inputs[0], feed_a_0)
|
||||
self.assertEqual(c.inputs[1], b.outputs[0])
|
||||
@ -260,7 +264,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
node { name: 'A' op: 'Oii' }
|
||||
node { name: 'B' op: 'Ii' input: 'A' }
|
||||
"""),
|
||||
return_elements=['A', 'B'])
|
||||
return_elements=["A", "B"])
|
||||
|
||||
self.assertEqual(b.inputs[0], a.outputs[0])
|
||||
|
||||
@ -272,8 +276,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
node { name: 'A' op: 'Oii' }
|
||||
node { name: 'B' op: 'Ii' input: 'A:0' }
|
||||
"""),
|
||||
input_map={'A': feed_a_0},
|
||||
return_elements=['B'])
|
||||
input_map={"A": feed_a_0},
|
||||
return_elements=["B"])
|
||||
|
||||
self.assertEqual(b.inputs[0], feed_a_0)
|
||||
|
||||
@ -284,7 +288,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
node { name: 'A' op: 'None' }
|
||||
node { name: 'B' op: 'None' input: '^A' }
|
||||
"""),
|
||||
return_elements=['A', 'B'])
|
||||
return_elements=["A", "B"])
|
||||
|
||||
self.assertEqual(b.control_inputs, [a])
|
||||
|
||||
@ -297,7 +301,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
node { name: 'C' op: 'Iii' input: 'A:0' input: 'B:0' }
|
||||
node { name: 'D' op: 'Iri' input: 'A:0' input: 'B:0' }
|
||||
"""),
|
||||
return_elements=['A', 'B', 'C', 'D'])
|
||||
return_elements=["A", "B", "C", "D"])
|
||||
|
||||
self.assertEqual(c.inputs[0], a.outputs[0])
|
||||
self.assertEqual(c.inputs[1], b.outputs[0])
|
||||
@ -320,7 +324,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
node { name: 'B' op: 'Unary'
|
||||
attr { key: 'T' value { type: DT_INT32 } } input: 'A:0' }
|
||||
"""),
|
||||
return_elements=['A', 'B'])
|
||||
return_elements=["A", "B"])
|
||||
|
||||
self.assertEqual(a.inputs[0], b.outputs[0])
|
||||
self.assertEqual(b.inputs[0], a.outputs[0])
|
||||
@ -334,7 +338,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
node { name: 'B' op: 'If' input: 'A:0' }
|
||||
"""))
|
||||
self.assertTrue(
|
||||
'Cannot convert a tensor of type int32 to an input of type float' in
|
||||
"Cannot convert a tensor of type int32 to an input of type float" in
|
||||
str(e.exception))
|
||||
|
||||
def testInvalidSignatureTooManyInputsInGraphDef(self):
|
||||
@ -345,7 +349,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
node { name: 'A' op: 'Oi' }
|
||||
node { name: 'B' op: 'None' input: 'A:0' }
|
||||
"""))
|
||||
self.assertTrue('More inputs specified (\'A:0\') than the op expects' in
|
||||
self.assertTrue("More inputs specified ('A:0') than the op expects" in
|
||||
str(e.exception))
|
||||
|
||||
def testInvalidSignatureNotEnoughInputsInGraphDef(self):
|
||||
@ -356,8 +360,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
node { name: 'A' op: 'Oi' }
|
||||
node { name: 'B' op: 'Iif' input: 'A:0' }
|
||||
"""))
|
||||
self.assertTrue('Input types mismatch (expected \'int32, float32\' but '
|
||||
'got \'int32\')' in str(e.exception))
|
||||
self.assertTrue("Input types mismatch (expected 'int32, float32' but "
|
||||
"got 'int32')" in str(e.exception))
|
||||
|
||||
def testMissingInputOpInGraphDef(self):
|
||||
with tf.Graph().as_default():
|
||||
@ -375,8 +379,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'B' op: 'If' input: 'A:0' }
|
||||
"""),
|
||||
input_map={'A:0': feed_a_0},
|
||||
return_elements=['B'])
|
||||
input_map={"A:0": feed_a_0},
|
||||
return_elements=["B"])
|
||||
self.assertEqual(b.inputs[0], feed_a_0)
|
||||
|
||||
def testMissingInputTensorInGraphDef(self):
|
||||
@ -425,7 +429,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'None' }
|
||||
"""),
|
||||
return_elements=['B'])
|
||||
return_elements=["B"])
|
||||
self.assertTrue("return_element 'B' not found in graph_def." in
|
||||
str(e.exception))
|
||||
|
||||
@ -436,7 +440,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'Oi' }
|
||||
"""),
|
||||
return_elements=['A:1'])
|
||||
return_elements=["A:1"])
|
||||
self.assertTrue("return_element 'A:1' not found in graph_def." in
|
||||
str(e.exception))
|
||||
|
||||
@ -445,7 +449,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'Oi' }
|
||||
"""),
|
||||
return_elements=['B:0'])
|
||||
return_elements=["B:0"])
|
||||
self.assertTrue("return_element 'B:0' not found in graph_def." in
|
||||
str(e.exception))
|
||||
|
||||
@ -454,7 +458,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'Oi' }
|
||||
"""),
|
||||
return_elements=['A:B:0'])
|
||||
return_elements=["A:B:0"])
|
||||
self.assertTrue("return_element 'A:B:0' not found in graph_def." in
|
||||
str(e.exception))
|
||||
|
||||
@ -465,8 +469,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'None' }
|
||||
"""),
|
||||
input_map={'B:0': tf.constant(5.0)})
|
||||
self.assertTrue('not found in graph_def: [B:0]' in str(e.exception))
|
||||
input_map={"B:0": tf.constant(5.0)})
|
||||
self.assertTrue("not found in graph_def: [B:0]" in str(e.exception))
|
||||
|
||||
def testInputMapTypeMismatch(self):
|
||||
with tf.Graph().as_default():
|
||||
@ -476,9 +480,9 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
node { name: 'A' op: 'Oi' }
|
||||
node { name: 'B' op: 'Ii' input: 'A:0' }
|
||||
"""),
|
||||
input_map={'A:0': tf.constant(5.0)})
|
||||
input_map={"A:0": tf.constant(5.0)})
|
||||
self.assertTrue(
|
||||
'Cannot convert a tensor of type float32 to an input of type int32.'
|
||||
"Cannot convert a tensor of type float32 to an input of type int32."
|
||||
in str(e.exception))
|
||||
|
||||
def testNoReturns(self):
|
||||
@ -489,8 +493,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
"""))
|
||||
self.assertEqual(ret, None)
|
||||
|
||||
a = g.get_operation_by_name('import/A')
|
||||
self.assertEqual(a.type, 'None')
|
||||
a = g.get_operation_by_name("import/A")
|
||||
self.assertEqual(a.type, "None")
|
||||
|
||||
def testOverrideNamePrefix(self):
|
||||
with tf.Graph().as_default():
|
||||
@ -498,8 +502,8 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'None' }
|
||||
"""),
|
||||
return_elements=['A'], name='imported_graph')
|
||||
self.assertEqual(a.name, 'imported_graph/A')
|
||||
return_elements=["A"], name="imported_graph")
|
||||
self.assertEqual(a.name, "imported_graph/A")
|
||||
|
||||
def testNamePrefixColocationAttrs(self):
|
||||
original_graph_def = self._MakeGraphDef("""
|
||||
@ -511,7 +515,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
|
||||
with tf.Graph().as_default():
|
||||
b, = tf.import_graph_def(original_graph_def,
|
||||
return_elements=['B'], name='imported_graph')
|
||||
return_elements=["B"], name="imported_graph")
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: 'imported_graph/A' op: 'None' }
|
||||
node { name: 'imported_graph/B' op: 'None' attr {
|
||||
@ -529,9 +533,9 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
|
||||
with tf.Graph().as_default():
|
||||
b, = tf.import_graph_def(original_graph_def,
|
||||
return_elements=['B'], name='')
|
||||
return_elements=["B"], name="")
|
||||
_, = tf.import_graph_def(original_graph_def,
|
||||
return_elements=['B'], name='')
|
||||
return_elements=["B"], name="")
|
||||
self.assertProtoEqualsVersion("""
|
||||
node { name: 'A' op: 'None' }
|
||||
node { name: 'B' op: 'None' attr {
|
||||
@ -551,90 +555,90 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
value { list { s: 'loc:@A' } }
|
||||
} }""")
|
||||
with tf.Graph().as_default():
|
||||
with self.assertRaisesRegexp(ValueError, 'does not exist during import'):
|
||||
with self.assertRaisesRegexp(ValueError, "does not exist during import"):
|
||||
tf.import_graph_def(original_graph_def,
|
||||
return_elements=['B'], name='imported_graph')
|
||||
return_elements=["B"], name="imported_graph")
|
||||
|
||||
def testEmptyGraph(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
init_version = g.version
|
||||
tf.import_graph_def(self._MakeGraphDef(''))
|
||||
tf.import_graph_def(self._MakeGraphDef(""))
|
||||
self.assertEqual(init_version, g.version)
|
||||
|
||||
def testInvalidInputForGraphDef(self):
|
||||
with tf.Graph().as_default():
|
||||
with self.assertRaises(TypeError) as e:
|
||||
tf.import_graph_def('')
|
||||
tf.import_graph_def("")
|
||||
self.assertEqual(
|
||||
'graph_def must be a GraphDef proto.', str(e.exception))
|
||||
"graph_def must be a GraphDef proto.", str(e.exception))
|
||||
|
||||
def testInvalidInputForInputMap(self):
|
||||
with tf.Graph().as_default():
|
||||
with self.assertRaises(TypeError) as e:
|
||||
tf.import_graph_def(self._MakeGraphDef(''),
|
||||
input_map=[tf.constant(5.0)])
|
||||
self.assertEqual('input_map must be a dictionary mapping strings to '
|
||||
'Tensor objects.', str(e.exception))
|
||||
tf.import_graph_def(self._MakeGraphDef(""),
|
||||
input_map=[tf.constant(5.0)])
|
||||
self.assertEqual("input_map must be a dictionary mapping strings to "
|
||||
"Tensor objects.", str(e.exception))
|
||||
|
||||
def testInvalidInputForReturnOperations(self):
|
||||
with tf.Graph().as_default():
|
||||
with self.assertRaises(TypeError) as e:
|
||||
tf.import_graph_def(self._MakeGraphDef(''), return_elements=[7])
|
||||
tf.import_graph_def(self._MakeGraphDef(""), return_elements=[7])
|
||||
self.assertEqual(
|
||||
'return_elements must be a list of strings.', str(e.exception))
|
||||
"return_elements must be a list of strings.", str(e.exception))
|
||||
|
||||
def testWithExtensionAndAttr(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
c = tf.constant(5.0, dtype=tf.float32, name='c')
|
||||
tf.pack([c, c], name='pack')
|
||||
c = tf.constant(5.0, dtype=tf.float32, name="c")
|
||||
tf.pack([c, c], name="pack")
|
||||
gdef = g.as_graph_def()
|
||||
|
||||
with self.test_session():
|
||||
pack, = tf.import_graph_def(gdef, return_elements=['pack'])
|
||||
pack, = tf.import_graph_def(gdef, return_elements=["pack"])
|
||||
self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
|
||||
|
||||
def testWithDevice(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
# No device.
|
||||
a = tf.constant(3.0, name='a')
|
||||
a = tf.constant(3.0, name="a")
|
||||
|
||||
with tf.device('/cpu:0'):
|
||||
b = tf.constant(4.0, name='b')
|
||||
with tf.device('/job:worker'):
|
||||
c = tf.constant(5.0, name='c')
|
||||
with tf.device("/cpu:0"):
|
||||
b = tf.constant(4.0, name="b")
|
||||
with tf.device("/job:worker"):
|
||||
c = tf.constant(5.0, name="c")
|
||||
|
||||
gdef = g.as_graph_def()
|
||||
|
||||
with tf.Graph().as_default():
|
||||
a2, b2, c2 = tf.import_graph_def(
|
||||
gdef, return_elements=['a', 'b', 'c'])
|
||||
gdef, return_elements=["a", "b", "c"])
|
||||
self.assertEqual(a.device, a2.device)
|
||||
self.assertEqual(b.device, b2.device)
|
||||
self.assertEqual(c.device, c2.device)
|
||||
|
||||
with tf.Graph().as_default():
|
||||
with tf.device(device.merge_device('/task:0')):
|
||||
with tf.device(device.merge_device("/task:0")):
|
||||
a3, b3, c3 = tf.import_graph_def(
|
||||
gdef, return_elements=['a', 'b', 'c'])
|
||||
self.assertEqual('/task:0', a3.device)
|
||||
self.assertEqual('/task:0/device:CPU:0', b3.device) # canonicalized.
|
||||
self.assertEqual(c.device + '/task:0', c3.device)
|
||||
gdef, return_elements=["a", "b", "c"])
|
||||
self.assertEqual("/task:0", a3.device)
|
||||
self.assertEqual("/task:0/device:CPU:0", b3.device) # canonicalized.
|
||||
self.assertEqual(c.device + "/task:0", c3.device)
|
||||
|
||||
with tf.Graph().as_default():
|
||||
with tf.device(device.merge_device('/job:ps')):
|
||||
with tf.device(device.merge_device("/job:ps")):
|
||||
a4, b4, c4 = tf.import_graph_def(
|
||||
gdef, return_elements=['a', 'b', 'c'])
|
||||
self.assertEqual('/job:ps', a4.device)
|
||||
self.assertEqual('/job:ps/device:CPU:0', b4.device) # canonicalized.
|
||||
gdef, return_elements=["a", "b", "c"])
|
||||
self.assertEqual("/job:ps", a4.device)
|
||||
self.assertEqual("/job:ps/device:CPU:0", b4.device) # canonicalized.
|
||||
self.assertEqual(c.device, c4.device) # worker overrides ps.
|
||||
|
||||
with tf.Graph().as_default():
|
||||
with tf.device(device.merge_device('/gpu:0')):
|
||||
with tf.device(device.merge_device("/gpu:0")):
|
||||
a5, b5, c5 = tf.import_graph_def(
|
||||
gdef, return_elements=['a', 'b', 'c'])
|
||||
self.assertEqual('/device:GPU:0', a5.device)
|
||||
self.assertEqual('/device:CPU:0', b5.device) # cpu overrides gpu.
|
||||
self.assertEqual(c.device + '/device:GPU:0', c5.device)
|
||||
gdef, return_elements=["a", "b", "c"])
|
||||
self.assertEqual("/device:GPU:0", a5.device)
|
||||
self.assertEqual("/device:CPU:0", b5.device) # cpu overrides gpu.
|
||||
self.assertEqual(c.device + "/device:GPU:0", c5.device)
|
||||
|
||||
def testWithDeviceFunctionDependingOnInputs(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
@ -706,7 +710,7 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
a, = tf.import_graph_def(
|
||||
self._MakeGraphDef("node { name: 'A' op: 'Oii' }",
|
||||
producer=producer, min_consumer=min_consumer),
|
||||
return_elements=['A'])
|
||||
return_elements=["A"])
|
||||
self.assertEqual(a.graph.graph_def_versions.producer, producer)
|
||||
self.assertEqual(a.graph.graph_def_versions.min_consumer,
|
||||
min_consumer)
|
||||
@ -739,9 +743,38 @@ class ImportGraphDefTest(tf.test.TestCase):
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'OpWithDefaultAttr' }
|
||||
"""),
|
||||
return_elements=['A'])
|
||||
return_elements=["A"])
|
||||
self.assertEqual(123.0, a[0].get_attr("default_float"))
|
||||
|
||||
def testDefaultAttrsRemoved(self):
|
||||
producer_op_list = op_def_pb2.OpList()
|
||||
text_format.Merge("""
|
||||
op {
|
||||
name: 'OpWithFutureDefaultAttr'
|
||||
attr { name: 'default_int' type: 'int' default_value { i: 456 } }
|
||||
}
|
||||
""", producer_op_list)
|
||||
# Attr only in producer_op_list with default value gets removed.
|
||||
with tf.Graph().as_default():
|
||||
a = tf.import_graph_def(
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'OpWithFutureDefaultAttr'
|
||||
attr { key: 'default_int' value { i: 456 } } }
|
||||
"""),
|
||||
return_elements=["A"], producer_op_list=producer_op_list)
|
||||
with self.assertRaisesRegexp(ValueError, "No attr named 'default_int'"):
|
||||
a[0].get_attr("default_int")
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Attr only in producer_op_list with non-default value is preserved.
|
||||
with tf.Graph().as_default():
|
||||
a = tf.import_graph_def(
|
||||
self._MakeGraphDef("""
|
||||
node { name: 'A' op: 'OpWithFutureDefaultAttr'
|
||||
attr { key: 'default_int' value { i: 987 } } }
|
||||
"""),
|
||||
return_elements=["A"], producer_op_list=producer_op_list)
|
||||
self.assertEqual(987, a[0].get_attr("default_int"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
||||
@ -1317,7 +1317,11 @@ def _import_meta_graph_def(meta_graph_def):
|
||||
(i.e., no variables to restore).
|
||||
"""
|
||||
# Gathers the list of nodes we are interested in.
|
||||
importer.import_graph_def(meta_graph_def.graph_def, name="")
|
||||
producer_op_list = None
|
||||
if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
|
||||
producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
|
||||
importer.import_graph_def(meta_graph_def.graph_def, name="",
|
||||
producer_op_list=producer_op_list)
|
||||
|
||||
# Restores all the other collections.
|
||||
for key, col_def in meta_graph_def.collection_def.items():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user