Pass all the arguments at once in the generator to fetch_type() in _DictFetchMapper and ensure the same default_factory for collections.defaultdict type.
PiperOrigin-RevId: 330000469 Change-Id: Id09a5f79528164b18f5b78a31ee1e13c3ca1518d
This commit is contained in:
parent
2cea1ba130
commit
8320d33533
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import re
|
||||
import threading
|
||||
@ -41,13 +42,14 @@ from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.experimental import mixed_precision_global_state
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
_python_session_create_counter = monitoring.Counter(
|
||||
'/tensorflow/api/python/session_create_counter',
|
||||
'Counter for number of sessions created in Python.')
|
||||
|
||||
|
||||
class SessionInterface(object):
|
||||
"""Base class for implementations of TensorFlow client sessions."""
|
||||
|
||||
@ -406,6 +408,12 @@ class _DictFetchMapper(_FetchMapper):
|
||||
fetches: Dict of fetches.
|
||||
"""
|
||||
self._fetch_type = type(fetches)
|
||||
if isinstance(fetches, collections.defaultdict):
|
||||
self._type_ctor = functools.partial(collections.defaultdict,
|
||||
fetches.default_factory)
|
||||
else:
|
||||
self._type_ctor = self._fetch_type
|
||||
|
||||
self._keys = fetches.keys()
|
||||
self._mappers = [
|
||||
_FetchMapper.for_fetch(fetch) for fetch in fetches.values()
|
||||
@ -416,10 +424,12 @@ class _DictFetchMapper(_FetchMapper):
|
||||
return self._unique_fetches
|
||||
|
||||
def build_results(self, values):
|
||||
results = self._fetch_type()
|
||||
for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
|
||||
results[k] = m.build_results([values[j] for j in vi])
|
||||
return results
|
||||
|
||||
def _generator():
|
||||
for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
|
||||
yield k, m.build_results([values[j] for j in vi])
|
||||
|
||||
return self._type_ctor(_generator())
|
||||
|
||||
|
||||
class _AttrsFetchMapper(_FetchMapper):
|
||||
|
@ -18,8 +18,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import random
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@ -68,6 +68,13 @@ try:
|
||||
except ImportError:
|
||||
attr = None
|
||||
|
||||
try:
|
||||
from frozendict import frozendict # pylint:disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
frozendict = dict # pylint:disable=invalid-name
|
||||
|
||||
defaultdict = collections.defaultdict # pylint:disable=invalid-name
|
||||
|
||||
|
||||
class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@ -222,13 +229,13 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
res = sess.run(a)
|
||||
self.assertEqual(42.0, res)
|
||||
res = sess.run(a.op) # An op, not a tensor.
|
||||
self.assertEqual(None, res)
|
||||
self.assertIsNone(res)
|
||||
tensor_runner = sess.make_callable(a)
|
||||
res = tensor_runner()
|
||||
self.assertEqual(42.0, res)
|
||||
op_runner = sess.make_callable(a.op)
|
||||
res = op_runner()
|
||||
self.assertEqual(None, res)
|
||||
self.assertIsNone(res)
|
||||
|
||||
def testFetchSingletonByName(self):
|
||||
with session.Session() as sess:
|
||||
@ -236,7 +243,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
res = sess.run(a.name)
|
||||
self.assertEqual(42.0, res)
|
||||
res = sess.run(a.op) # An op, not a tensor.
|
||||
self.assertEqual(None, res)
|
||||
self.assertIsNone(res)
|
||||
|
||||
def testFetchList(self):
|
||||
with session.Session() as sess:
|
||||
@ -246,11 +253,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
v = variables.Variable([54.0])
|
||||
assign = v.assign([63.0])
|
||||
res = sess.run([a, b, c, a.name, assign.op])
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertIsInstance(res, list)
|
||||
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
||||
list_runner = sess.make_callable([a, b, c, a.name, assign.op])
|
||||
res = list_runner()
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertIsInstance(res, list)
|
||||
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
||||
|
||||
def testFetchTuple(self):
|
||||
@ -259,11 +266,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||
c = constant_op.constant(44.0)
|
||||
res = sess.run((a, b, c, a.name))
|
||||
self.assertTrue(isinstance(res, tuple))
|
||||
self.assertIsInstance(res, tuple)
|
||||
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
||||
tuple_runner = sess.make_callable((a, b, c, a.name))
|
||||
res = tuple_runner()
|
||||
self.assertTrue(isinstance(res, tuple))
|
||||
self.assertIsInstance(res, tuple)
|
||||
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
||||
|
||||
def testFetchNamedTuple(self):
|
||||
@ -275,15 +282,15 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||
c = constant_op.constant(44.0)
|
||||
res = sess.run(ABC(a, b, c))
|
||||
self.assertTrue(isinstance(res, ABC))
|
||||
self.assertIsInstance(res, ABC)
|
||||
self.assertEqual(42.0, res.a)
|
||||
self.assertEqual(None, res.b)
|
||||
self.assertIsNone(res.b)
|
||||
self.assertEqual(44.0, res.c)
|
||||
namedtuple_runner = sess.make_callable(ABC(a, b, c))
|
||||
res = namedtuple_runner()
|
||||
self.assertTrue(isinstance(res, ABC))
|
||||
self.assertIsInstance(res, ABC)
|
||||
self.assertEqual(42.0, res.a)
|
||||
self.assertEqual(None, res.b)
|
||||
self.assertIsNone(res.b)
|
||||
self.assertEqual(44.0, res.c)
|
||||
|
||||
def testFetchDict(self):
|
||||
@ -292,9 +299,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||
c = constant_op.constant(44.0)
|
||||
res = sess.run({'a': a, 'b': b, 'c': c})
|
||||
self.assertTrue(isinstance(res, dict))
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertEqual(42.0, res['a'])
|
||||
self.assertEqual(None, res['b'])
|
||||
self.assertIsNone(res['b'])
|
||||
self.assertEqual(44.0, res['c'])
|
||||
|
||||
def testFetchOrderedDict(self):
|
||||
@ -303,10 +310,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||
c = constant_op.constant(44.0)
|
||||
res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
|
||||
self.assertTrue(isinstance(res, collections.OrderedDict))
|
||||
self.assertIsInstance(res, collections.OrderedDict)
|
||||
self.assertEqual([3, 2, 1], list(res.keys()))
|
||||
self.assertEqual(42.0, res[3])
|
||||
self.assertEqual(None, res[2])
|
||||
self.assertIsNone(res[2])
|
||||
self.assertEqual(44.0, res[1])
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
@ -393,23 +400,23 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
a = constant_op.constant(a_val)
|
||||
|
||||
res = sess.run([[], tuple(), {}])
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertIsInstance(res, list)
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertTrue(isinstance(res[0], list))
|
||||
self.assertIsInstance(res[0], list)
|
||||
self.assertEqual(0, len(res[0]))
|
||||
self.assertTrue(isinstance(res[1], tuple))
|
||||
self.assertIsInstance(res[1], tuple)
|
||||
self.assertEqual(0, len(res[1]))
|
||||
self.assertTrue(isinstance(res[2], dict))
|
||||
self.assertIsInstance(res[2], dict)
|
||||
self.assertEqual(0, len(res[2]))
|
||||
|
||||
res = sess.run([[], tuple(), {}, a])
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertIsInstance(res, list)
|
||||
self.assertEqual(4, len(res))
|
||||
self.assertTrue(isinstance(res[0], list))
|
||||
self.assertIsInstance(res[0], list)
|
||||
self.assertEqual(0, len(res[0]))
|
||||
self.assertTrue(isinstance(res[1], tuple))
|
||||
self.assertIsInstance(res[1], tuple)
|
||||
self.assertEqual(0, len(res[1]))
|
||||
self.assertTrue(isinstance(res[2], dict))
|
||||
self.assertIsInstance(res[2], dict)
|
||||
self.assertEqual(0, len(res[2]))
|
||||
self.assertEqual(a_val, res[3])
|
||||
|
||||
@ -417,7 +424,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
with session.Session() as sess:
|
||||
# pylint: disable=invalid-name
|
||||
ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
|
||||
DEFG = collections.namedtuple('DEFG', ['d', 'e', 'f', 'g'])
|
||||
DEFGHI = collections.namedtuple('DEFGHI', ['d', 'e', 'f', 'g', 'h', 'i'])
|
||||
# pylint: enable=invalid-name
|
||||
a_val = 42.0
|
||||
b_val = None
|
||||
@ -425,124 +432,141 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
a = constant_op.constant(a_val)
|
||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||
c = constant_op.constant(c_val)
|
||||
# List of lists, tuples, namedtuple, and dict
|
||||
res = sess.run([[a, b, c], (a, b, c),
|
||||
ABC(a=a, b=b, c=c), {
|
||||
'a': a.name,
|
||||
'c': c,
|
||||
'b': b
|
||||
}])
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertEqual(4, len(res))
|
||||
self.assertTrue(isinstance(res[0], list))
|
||||
test_dct = {'a': a.name, 'c': c, 'b': b}
|
||||
test_dct_types = [dict, frozendict, defaultdict]
|
||||
# List of lists, tuples, namedtuple, dict, frozendict, and defaultdict
|
||||
res = sess.run([
|
||||
[a, b, c],
|
||||
(a, b, c),
|
||||
ABC(a=a, b=b, c=c),
|
||||
dict(test_dct),
|
||||
frozendict(test_dct),
|
||||
defaultdict(str, test_dct),
|
||||
])
|
||||
self.assertIsInstance(res, list)
|
||||
self.assertEqual(6, len(res))
|
||||
self.assertIsInstance(res[0], list)
|
||||
self.assertEqual(3, len(res[0]))
|
||||
self.assertEqual(a_val, res[0][0])
|
||||
self.assertEqual(b_val, res[0][1])
|
||||
self.assertEqual(c_val, res[0][2])
|
||||
self.assertTrue(isinstance(res[1], tuple))
|
||||
self.assertIsInstance(res[1], tuple)
|
||||
self.assertEqual(3, len(res[1]))
|
||||
self.assertEqual(a_val, res[1][0])
|
||||
self.assertEqual(b_val, res[1][1])
|
||||
self.assertEqual(c_val, res[1][2])
|
||||
self.assertTrue(isinstance(res[2], ABC))
|
||||
self.assertIsInstance(res[2], ABC)
|
||||
self.assertEqual(a_val, res[2].a)
|
||||
self.assertEqual(b_val, res[2].b)
|
||||
self.assertEqual(c_val, res[2].c)
|
||||
self.assertTrue(isinstance(res[3], dict))
|
||||
self.assertEqual(3, len(res[3]))
|
||||
self.assertEqual(a_val, res[3]['a'])
|
||||
self.assertEqual(b_val, res[3]['b'])
|
||||
self.assertEqual(c_val, res[3]['c'])
|
||||
# Tuple of lists, tuples, namedtuple, and dict
|
||||
res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), {
|
||||
'a': a,
|
||||
'c': c,
|
||||
'b': b
|
||||
}))
|
||||
self.assertTrue(isinstance(res, tuple))
|
||||
self.assertEqual(4, len(res))
|
||||
self.assertTrue(isinstance(res[0], list))
|
||||
for expected_type, r in zip(test_dct_types, res[3:]):
|
||||
self.assertIsInstance(r, expected_type)
|
||||
self.assertEqual(3, len(r))
|
||||
self.assertEqual(a_val, r['a'])
|
||||
self.assertEqual(b_val, r['b'])
|
||||
self.assertEqual(c_val, r['c'])
|
||||
self.assertEqual(res[5].default_factory, str)
|
||||
# Tuple of lists, tuples, namedtuple, dict, frozendict, and defaultdict
|
||||
res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b,
|
||||
c=c), dict(test_dct),
|
||||
frozendict(test_dct), defaultdict(str, test_dct)))
|
||||
self.assertIsInstance(res, tuple)
|
||||
self.assertEqual(6, len(res))
|
||||
self.assertIsInstance(res[0], list)
|
||||
self.assertEqual(3, len(res[0]))
|
||||
self.assertEqual(a_val, res[0][0])
|
||||
self.assertEqual(b_val, res[0][1])
|
||||
self.assertEqual(c_val, res[0][2])
|
||||
self.assertTrue(isinstance(res[1], tuple))
|
||||
self.assertIsInstance(res[1], tuple)
|
||||
self.assertEqual(3, len(res[1]))
|
||||
self.assertEqual(a_val, res[1][0])
|
||||
self.assertEqual(b_val, res[1][1])
|
||||
self.assertEqual(c_val, res[1][2])
|
||||
self.assertTrue(isinstance(res[2], ABC))
|
||||
self.assertIsInstance(res[2], ABC)
|
||||
self.assertEqual(a_val, res[2].a)
|
||||
self.assertEqual(b_val, res[2].b)
|
||||
self.assertEqual(c_val, res[2].c)
|
||||
self.assertTrue(isinstance(res[3], dict))
|
||||
self.assertEqual(3, len(res[3]))
|
||||
self.assertEqual(a_val, res[3]['a'])
|
||||
self.assertEqual(b_val, res[3]['b'])
|
||||
self.assertEqual(c_val, res[3]['c'])
|
||||
# Namedtuple of lists, tuples, namedtuples, and dict
|
||||
for expected_type, r in zip(test_dct_types, res[3:]):
|
||||
self.assertIsInstance(r, expected_type)
|
||||
self.assertEqual(3, len(r))
|
||||
self.assertEqual(a_val, r['a'])
|
||||
self.assertEqual(b_val, r['b'])
|
||||
self.assertEqual(c_val, r['c'])
|
||||
self.assertEqual(res[5].default_factory, str)
|
||||
|
||||
# Namedtuple of lists, tuples, namedtuples, dict, frozendict, defaultdict
|
||||
res = sess.run(
|
||||
DEFG(
|
||||
DEFGHI(
|
||||
d=[a, b, c],
|
||||
e=(a, b, c),
|
||||
f=ABC(a=a.name, b=b, c=c),
|
||||
g={
|
||||
'a': a,
|
||||
'c': c,
|
||||
'b': b
|
||||
}))
|
||||
self.assertTrue(isinstance(res, DEFG))
|
||||
self.assertTrue(isinstance(res.d, list))
|
||||
g=dict(test_dct),
|
||||
h=frozendict(test_dct),
|
||||
i=defaultdict(str, test_dct)))
|
||||
self.assertIsInstance(res, DEFGHI)
|
||||
self.assertIsInstance(res.d, list)
|
||||
self.assertEqual(3, len(res.d))
|
||||
self.assertEqual(a_val, res.d[0])
|
||||
self.assertEqual(b_val, res.d[1])
|
||||
self.assertEqual(c_val, res.d[2])
|
||||
self.assertTrue(isinstance(res.e, tuple))
|
||||
self.assertIsInstance(res.e, tuple)
|
||||
self.assertEqual(3, len(res.e))
|
||||
self.assertEqual(a_val, res.e[0])
|
||||
self.assertEqual(b_val, res.e[1])
|
||||
self.assertEqual(c_val, res.e[2])
|
||||
self.assertTrue(isinstance(res.f, ABC))
|
||||
self.assertIsInstance(res.f, ABC)
|
||||
self.assertEqual(a_val, res.f.a)
|
||||
self.assertEqual(b_val, res.f.b)
|
||||
self.assertEqual(c_val, res.f.c)
|
||||
self.assertTrue(isinstance(res.g, dict))
|
||||
self.assertIsInstance(res.g, dict)
|
||||
self.assertEqual(3, len(res.g))
|
||||
self.assertEqual(a_val, res.g['a'])
|
||||
self.assertEqual(b_val, res.g['b'])
|
||||
self.assertEqual(c_val, res.g['c'])
|
||||
# Dict of lists, tuples, namedtuples, and dict
|
||||
self.assertIsInstance(res.h, frozendict)
|
||||
self.assertEqual(3, len(res.h))
|
||||
self.assertEqual(a_val, res.h['a'])
|
||||
self.assertEqual(b_val, res.h['b'])
|
||||
self.assertEqual(c_val, res.h['c'])
|
||||
self.assertIsInstance(res.i, defaultdict)
|
||||
self.assertEqual(3, len(res.i))
|
||||
self.assertEqual(a_val, res.i['a'])
|
||||
self.assertEqual(b_val, res.i['b'])
|
||||
self.assertEqual(c_val, res.i['c'])
|
||||
self.assertEqual(res.i.default_factory, str)
|
||||
# Dict of lists, tuples, namedtuples, dict, frozendict, defaultdict
|
||||
res = sess.run({
|
||||
'd': [a, b, c],
|
||||
'e': (a, b, c),
|
||||
'f': ABC(a=a, b=b, c=c),
|
||||
'g': {
|
||||
'a': a.name,
|
||||
'c': c,
|
||||
'b': b
|
||||
}
|
||||
'g': dict(test_dct),
|
||||
'h': frozendict(test_dct),
|
||||
'i': defaultdict(str, test_dct),
|
||||
})
|
||||
self.assertTrue(isinstance(res, dict))
|
||||
self.assertEqual(4, len(res))
|
||||
self.assertTrue(isinstance(res['d'], list))
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertEqual(6, len(res))
|
||||
self.assertIsInstance(res['d'], list)
|
||||
self.assertEqual(3, len(res['d']))
|
||||
self.assertEqual(a_val, res['d'][0])
|
||||
self.assertEqual(b_val, res['d'][1])
|
||||
self.assertEqual(c_val, res['d'][2])
|
||||
self.assertTrue(isinstance(res['e'], tuple))
|
||||
self.assertIsInstance(res['e'], tuple)
|
||||
self.assertEqual(3, len(res['e']))
|
||||
self.assertEqual(a_val, res['e'][0])
|
||||
self.assertEqual(b_val, res['e'][1])
|
||||
self.assertEqual(c_val, res['e'][2])
|
||||
self.assertTrue(isinstance(res['f'], ABC))
|
||||
self.assertIsInstance(res['f'], ABC)
|
||||
self.assertEqual(a_val, res['f'].a)
|
||||
self.assertEqual(b_val, res['f'].b)
|
||||
self.assertEqual(c_val, res['f'].c)
|
||||
self.assertTrue(isinstance(res['g'], dict))
|
||||
self.assertEqual(3, len(res['g']))
|
||||
self.assertEqual(a_val, res['g']['a'])
|
||||
self.assertEqual(b_val, res['g']['b'])
|
||||
self.assertEqual(c_val, res['g']['c'])
|
||||
for expected_type, r_key in zip(test_dct_types, ('g', 'h', 'i')):
|
||||
r = res[r_key]
|
||||
self.assertIsInstance(r, expected_type)
|
||||
self.assertEqual(3, len(r))
|
||||
self.assertEqual(a_val, r['a'])
|
||||
self.assertEqual(b_val, r['b'])
|
||||
self.assertEqual(c_val, r['c'])
|
||||
self.assertEqual(res['i'].default_factory, str)
|
||||
|
||||
def testFetchTensorObject(self):
|
||||
with session.Session() as s:
|
||||
@ -1279,7 +1303,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
def testNotEntered(self):
|
||||
# pylint: disable=protected-access
|
||||
self.assertEqual(ops._default_session_stack.get_default(), None)
|
||||
self.assertIsNone(ops._default_session_stack.get_default())
|
||||
# pylint: enable=protected-access
|
||||
with ops.device('/cpu:0'):
|
||||
sess = session.Session()
|
||||
@ -1326,10 +1350,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
sess2 = session.InteractiveSession()
|
||||
self.assertEqual(1, len(w))
|
||||
self.assertTrue('An interactive session is already active. This can cause '
|
||||
'out-of-memory errors in some cases. You must explicitly '
|
||||
'call `InteractiveSession.close()` to release resources '
|
||||
'held by the other session(s).' in str(w[0].message))
|
||||
self.assertIn('An interactive session is already active. This can cause '
|
||||
'out-of-memory errors in some cases. You must explicitly '
|
||||
'call `InteractiveSession.close()` to release resources '
|
||||
'held by the other session(s).', str(w[0].message))
|
||||
sess2.close()
|
||||
sess.close()
|
||||
|
||||
@ -1610,10 +1634,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
e = constant_op.constant(44.0, name=b'e')
|
||||
f = constant_op.constant(45.0, name=r'f')
|
||||
|
||||
self.assertTrue(isinstance(c.name, six.text_type))
|
||||
self.assertTrue(isinstance(d.name, six.text_type))
|
||||
self.assertTrue(isinstance(e.name, six.text_type))
|
||||
self.assertTrue(isinstance(f.name, six.text_type))
|
||||
self.assertIsInstance(c.name, six.text_type)
|
||||
self.assertIsInstance(d.name, six.text_type)
|
||||
self.assertIsInstance(e.name, six.text_type)
|
||||
self.assertIsInstance(f.name, six.text_type)
|
||||
|
||||
self.assertEqual(42.0, sess.run('c:0'))
|
||||
self.assertEqual(42.0, sess.run(u'c:0'))
|
||||
@ -1673,10 +1697,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
with ops.device('/cpu:0'):
|
||||
with session.Session() as sess:
|
||||
sess.run(constant_op.constant(1.0))
|
||||
self.assertTrue(not run_metadata.HasField('step_stats'))
|
||||
self.assertFalse(run_metadata.HasField('step_stats'))
|
||||
|
||||
sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
|
||||
self.assertTrue(not run_metadata.HasField('step_stats'))
|
||||
self.assertFalse(run_metadata.HasField('step_stats'))
|
||||
|
||||
sess.run(
|
||||
constant_op.constant(1.0),
|
||||
@ -1697,11 +1721,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
|
||||
sess.run(
|
||||
constant_op.constant(1.0), options=None, run_metadata=run_metadata)
|
||||
self.assertTrue(not run_metadata.HasField('step_stats'))
|
||||
self.assertFalse(run_metadata.HasField('step_stats'))
|
||||
|
||||
sess.run(
|
||||
constant_op.constant(1.0), options=run_options, run_metadata=None)
|
||||
self.assertTrue(not run_metadata.HasField('step_stats'))
|
||||
self.assertFalse(run_metadata.HasField('step_stats'))
|
||||
|
||||
sess.run(
|
||||
constant_op.constant(1.0),
|
||||
@ -1730,9 +1754,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
with ops.Graph().as_default(), ops.device('/cpu:0'):
|
||||
a = constant_op.constant([[1, 2]])
|
||||
sess = session.Session()
|
||||
self.assertFalse('_output_shapes' in sess.graph_def.node[0].attr)
|
||||
self.assertNotIn('_output_shapes', sess.graph_def.node[0].attr)
|
||||
# Avoid lint error regarding 'unused' var a.
|
||||
self.assertTrue(a == a)
|
||||
self.assertEqual(a, a)
|
||||
|
||||
def testInferShapesTrue(self):
|
||||
config_pb = config_pb2.ConfigProto(
|
||||
@ -1740,9 +1764,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
with ops.Graph().as_default(), ops.device('/cpu:0'):
|
||||
a = constant_op.constant([[1, 2]])
|
||||
sess = session.Session(config=config_pb)
|
||||
self.assertTrue('_output_shapes' in sess.graph_def.node[0].attr)
|
||||
self.assertIn('_output_shapes', sess.graph_def.node[0].attr)
|
||||
# Avoid lint error regarding 'unused' var a.
|
||||
self.assertTrue(a == a)
|
||||
self.assertEqual(a, a)
|
||||
|
||||
def testBuildCostModel(self):
|
||||
run_options = config_pb2.RunOptions()
|
||||
|
Loading…
x
Reference in New Issue
Block a user