Pass all the arguments at once in the generator to fetch_type() in _DictFetchMapper and ensure the same default_factory for collections.defaultdict type.
PiperOrigin-RevId: 330000469 Change-Id: Id09a5f79528164b18f5b78a31ee1e13c3ca1518d
This commit is contained in:
parent
2cea1ba130
commit
8320d33533
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
import functools
|
import functools
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
@ -41,13 +42,14 @@ from tensorflow.python.platform import tf_logging as logging
|
|||||||
from tensorflow.python.training.experimental import mixed_precision_global_state
|
from tensorflow.python.training.experimental import mixed_precision_global_state
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
|
||||||
from tensorflow.python.util.compat import collections_abc
|
from tensorflow.python.util.compat import collections_abc
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
_python_session_create_counter = monitoring.Counter(
|
_python_session_create_counter = monitoring.Counter(
|
||||||
'/tensorflow/api/python/session_create_counter',
|
'/tensorflow/api/python/session_create_counter',
|
||||||
'Counter for number of sessions created in Python.')
|
'Counter for number of sessions created in Python.')
|
||||||
|
|
||||||
|
|
||||||
class SessionInterface(object):
|
class SessionInterface(object):
|
||||||
"""Base class for implementations of TensorFlow client sessions."""
|
"""Base class for implementations of TensorFlow client sessions."""
|
||||||
|
|
||||||
@ -406,6 +408,12 @@ class _DictFetchMapper(_FetchMapper):
|
|||||||
fetches: Dict of fetches.
|
fetches: Dict of fetches.
|
||||||
"""
|
"""
|
||||||
self._fetch_type = type(fetches)
|
self._fetch_type = type(fetches)
|
||||||
|
if isinstance(fetches, collections.defaultdict):
|
||||||
|
self._type_ctor = functools.partial(collections.defaultdict,
|
||||||
|
fetches.default_factory)
|
||||||
|
else:
|
||||||
|
self._type_ctor = self._fetch_type
|
||||||
|
|
||||||
self._keys = fetches.keys()
|
self._keys = fetches.keys()
|
||||||
self._mappers = [
|
self._mappers = [
|
||||||
_FetchMapper.for_fetch(fetch) for fetch in fetches.values()
|
_FetchMapper.for_fetch(fetch) for fetch in fetches.values()
|
||||||
@ -416,10 +424,12 @@ class _DictFetchMapper(_FetchMapper):
|
|||||||
return self._unique_fetches
|
return self._unique_fetches
|
||||||
|
|
||||||
def build_results(self, values):
|
def build_results(self, values):
|
||||||
results = self._fetch_type()
|
|
||||||
for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
|
def _generator():
|
||||||
results[k] = m.build_results([values[j] for j in vi])
|
for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
|
||||||
return results
|
yield k, m.build_results([values[j] for j in vi])
|
||||||
|
|
||||||
|
return self._type_ctor(_generator())
|
||||||
|
|
||||||
|
|
||||||
class _AttrsFetchMapper(_FetchMapper):
|
class _AttrsFetchMapper(_FetchMapper):
|
||||||
|
@ -18,8 +18,8 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import random
|
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@ -68,6 +68,13 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
attr = None
|
attr = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from frozendict import frozendict # pylint:disable=g-import-not-at-top
|
||||||
|
except ImportError:
|
||||||
|
frozendict = dict # pylint:disable=invalid-name
|
||||||
|
|
||||||
|
defaultdict = collections.defaultdict # pylint:disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class SessionTest(test_util.TensorFlowTestCase):
|
class SessionTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
@ -222,13 +229,13 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
res = sess.run(a)
|
res = sess.run(a)
|
||||||
self.assertEqual(42.0, res)
|
self.assertEqual(42.0, res)
|
||||||
res = sess.run(a.op) # An op, not a tensor.
|
res = sess.run(a.op) # An op, not a tensor.
|
||||||
self.assertEqual(None, res)
|
self.assertIsNone(res)
|
||||||
tensor_runner = sess.make_callable(a)
|
tensor_runner = sess.make_callable(a)
|
||||||
res = tensor_runner()
|
res = tensor_runner()
|
||||||
self.assertEqual(42.0, res)
|
self.assertEqual(42.0, res)
|
||||||
op_runner = sess.make_callable(a.op)
|
op_runner = sess.make_callable(a.op)
|
||||||
res = op_runner()
|
res = op_runner()
|
||||||
self.assertEqual(None, res)
|
self.assertIsNone(res)
|
||||||
|
|
||||||
def testFetchSingletonByName(self):
|
def testFetchSingletonByName(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
@ -236,7 +243,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
res = sess.run(a.name)
|
res = sess.run(a.name)
|
||||||
self.assertEqual(42.0, res)
|
self.assertEqual(42.0, res)
|
||||||
res = sess.run(a.op) # An op, not a tensor.
|
res = sess.run(a.op) # An op, not a tensor.
|
||||||
self.assertEqual(None, res)
|
self.assertIsNone(res)
|
||||||
|
|
||||||
def testFetchList(self):
|
def testFetchList(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
@ -246,11 +253,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
v = variables.Variable([54.0])
|
v = variables.Variable([54.0])
|
||||||
assign = v.assign([63.0])
|
assign = v.assign([63.0])
|
||||||
res = sess.run([a, b, c, a.name, assign.op])
|
res = sess.run([a, b, c, a.name, assign.op])
|
||||||
self.assertTrue(isinstance(res, list))
|
self.assertIsInstance(res, list)
|
||||||
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
||||||
list_runner = sess.make_callable([a, b, c, a.name, assign.op])
|
list_runner = sess.make_callable([a, b, c, a.name, assign.op])
|
||||||
res = list_runner()
|
res = list_runner()
|
||||||
self.assertTrue(isinstance(res, list))
|
self.assertIsInstance(res, list)
|
||||||
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
||||||
|
|
||||||
def testFetchTuple(self):
|
def testFetchTuple(self):
|
||||||
@ -259,11 +266,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||||
c = constant_op.constant(44.0)
|
c = constant_op.constant(44.0)
|
||||||
res = sess.run((a, b, c, a.name))
|
res = sess.run((a, b, c, a.name))
|
||||||
self.assertTrue(isinstance(res, tuple))
|
self.assertIsInstance(res, tuple)
|
||||||
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
||||||
tuple_runner = sess.make_callable((a, b, c, a.name))
|
tuple_runner = sess.make_callable((a, b, c, a.name))
|
||||||
res = tuple_runner()
|
res = tuple_runner()
|
||||||
self.assertTrue(isinstance(res, tuple))
|
self.assertIsInstance(res, tuple)
|
||||||
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
||||||
|
|
||||||
def testFetchNamedTuple(self):
|
def testFetchNamedTuple(self):
|
||||||
@ -275,15 +282,15 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||||
c = constant_op.constant(44.0)
|
c = constant_op.constant(44.0)
|
||||||
res = sess.run(ABC(a, b, c))
|
res = sess.run(ABC(a, b, c))
|
||||||
self.assertTrue(isinstance(res, ABC))
|
self.assertIsInstance(res, ABC)
|
||||||
self.assertEqual(42.0, res.a)
|
self.assertEqual(42.0, res.a)
|
||||||
self.assertEqual(None, res.b)
|
self.assertIsNone(res.b)
|
||||||
self.assertEqual(44.0, res.c)
|
self.assertEqual(44.0, res.c)
|
||||||
namedtuple_runner = sess.make_callable(ABC(a, b, c))
|
namedtuple_runner = sess.make_callable(ABC(a, b, c))
|
||||||
res = namedtuple_runner()
|
res = namedtuple_runner()
|
||||||
self.assertTrue(isinstance(res, ABC))
|
self.assertIsInstance(res, ABC)
|
||||||
self.assertEqual(42.0, res.a)
|
self.assertEqual(42.0, res.a)
|
||||||
self.assertEqual(None, res.b)
|
self.assertIsNone(res.b)
|
||||||
self.assertEqual(44.0, res.c)
|
self.assertEqual(44.0, res.c)
|
||||||
|
|
||||||
def testFetchDict(self):
|
def testFetchDict(self):
|
||||||
@ -292,9 +299,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||||
c = constant_op.constant(44.0)
|
c = constant_op.constant(44.0)
|
||||||
res = sess.run({'a': a, 'b': b, 'c': c})
|
res = sess.run({'a': a, 'b': b, 'c': c})
|
||||||
self.assertTrue(isinstance(res, dict))
|
self.assertIsInstance(res, dict)
|
||||||
self.assertEqual(42.0, res['a'])
|
self.assertEqual(42.0, res['a'])
|
||||||
self.assertEqual(None, res['b'])
|
self.assertIsNone(res['b'])
|
||||||
self.assertEqual(44.0, res['c'])
|
self.assertEqual(44.0, res['c'])
|
||||||
|
|
||||||
def testFetchOrderedDict(self):
|
def testFetchOrderedDict(self):
|
||||||
@ -303,10 +310,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||||
c = constant_op.constant(44.0)
|
c = constant_op.constant(44.0)
|
||||||
res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
|
res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
|
||||||
self.assertTrue(isinstance(res, collections.OrderedDict))
|
self.assertIsInstance(res, collections.OrderedDict)
|
||||||
self.assertEqual([3, 2, 1], list(res.keys()))
|
self.assertEqual([3, 2, 1], list(res.keys()))
|
||||||
self.assertEqual(42.0, res[3])
|
self.assertEqual(42.0, res[3])
|
||||||
self.assertEqual(None, res[2])
|
self.assertIsNone(res[2])
|
||||||
self.assertEqual(44.0, res[1])
|
self.assertEqual(44.0, res[1])
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
@test_util.run_v1_only('b/120545219')
|
||||||
@ -393,23 +400,23 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
a = constant_op.constant(a_val)
|
a = constant_op.constant(a_val)
|
||||||
|
|
||||||
res = sess.run([[], tuple(), {}])
|
res = sess.run([[], tuple(), {}])
|
||||||
self.assertTrue(isinstance(res, list))
|
self.assertIsInstance(res, list)
|
||||||
self.assertEqual(3, len(res))
|
self.assertEqual(3, len(res))
|
||||||
self.assertTrue(isinstance(res[0], list))
|
self.assertIsInstance(res[0], list)
|
||||||
self.assertEqual(0, len(res[0]))
|
self.assertEqual(0, len(res[0]))
|
||||||
self.assertTrue(isinstance(res[1], tuple))
|
self.assertIsInstance(res[1], tuple)
|
||||||
self.assertEqual(0, len(res[1]))
|
self.assertEqual(0, len(res[1]))
|
||||||
self.assertTrue(isinstance(res[2], dict))
|
self.assertIsInstance(res[2], dict)
|
||||||
self.assertEqual(0, len(res[2]))
|
self.assertEqual(0, len(res[2]))
|
||||||
|
|
||||||
res = sess.run([[], tuple(), {}, a])
|
res = sess.run([[], tuple(), {}, a])
|
||||||
self.assertTrue(isinstance(res, list))
|
self.assertIsInstance(res, list)
|
||||||
self.assertEqual(4, len(res))
|
self.assertEqual(4, len(res))
|
||||||
self.assertTrue(isinstance(res[0], list))
|
self.assertIsInstance(res[0], list)
|
||||||
self.assertEqual(0, len(res[0]))
|
self.assertEqual(0, len(res[0]))
|
||||||
self.assertTrue(isinstance(res[1], tuple))
|
self.assertIsInstance(res[1], tuple)
|
||||||
self.assertEqual(0, len(res[1]))
|
self.assertEqual(0, len(res[1]))
|
||||||
self.assertTrue(isinstance(res[2], dict))
|
self.assertIsInstance(res[2], dict)
|
||||||
self.assertEqual(0, len(res[2]))
|
self.assertEqual(0, len(res[2]))
|
||||||
self.assertEqual(a_val, res[3])
|
self.assertEqual(a_val, res[3])
|
||||||
|
|
||||||
@ -417,7 +424,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
|
ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
|
||||||
DEFG = collections.namedtuple('DEFG', ['d', 'e', 'f', 'g'])
|
DEFGHI = collections.namedtuple('DEFGHI', ['d', 'e', 'f', 'g', 'h', 'i'])
|
||||||
# pylint: enable=invalid-name
|
# pylint: enable=invalid-name
|
||||||
a_val = 42.0
|
a_val = 42.0
|
||||||
b_val = None
|
b_val = None
|
||||||
@ -425,124 +432,141 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
a = constant_op.constant(a_val)
|
a = constant_op.constant(a_val)
|
||||||
b = control_flow_ops.no_op() # An op, not a tensor.
|
b = control_flow_ops.no_op() # An op, not a tensor.
|
||||||
c = constant_op.constant(c_val)
|
c = constant_op.constant(c_val)
|
||||||
# List of lists, tuples, namedtuple, and dict
|
test_dct = {'a': a.name, 'c': c, 'b': b}
|
||||||
res = sess.run([[a, b, c], (a, b, c),
|
test_dct_types = [dict, frozendict, defaultdict]
|
||||||
ABC(a=a, b=b, c=c), {
|
# List of lists, tuples, namedtuple, dict, frozendict, and defaultdict
|
||||||
'a': a.name,
|
res = sess.run([
|
||||||
'c': c,
|
[a, b, c],
|
||||||
'b': b
|
(a, b, c),
|
||||||
}])
|
ABC(a=a, b=b, c=c),
|
||||||
self.assertTrue(isinstance(res, list))
|
dict(test_dct),
|
||||||
self.assertEqual(4, len(res))
|
frozendict(test_dct),
|
||||||
self.assertTrue(isinstance(res[0], list))
|
defaultdict(str, test_dct),
|
||||||
|
])
|
||||||
|
self.assertIsInstance(res, list)
|
||||||
|
self.assertEqual(6, len(res))
|
||||||
|
self.assertIsInstance(res[0], list)
|
||||||
self.assertEqual(3, len(res[0]))
|
self.assertEqual(3, len(res[0]))
|
||||||
self.assertEqual(a_val, res[0][0])
|
self.assertEqual(a_val, res[0][0])
|
||||||
self.assertEqual(b_val, res[0][1])
|
self.assertEqual(b_val, res[0][1])
|
||||||
self.assertEqual(c_val, res[0][2])
|
self.assertEqual(c_val, res[0][2])
|
||||||
self.assertTrue(isinstance(res[1], tuple))
|
self.assertIsInstance(res[1], tuple)
|
||||||
self.assertEqual(3, len(res[1]))
|
self.assertEqual(3, len(res[1]))
|
||||||
self.assertEqual(a_val, res[1][0])
|
self.assertEqual(a_val, res[1][0])
|
||||||
self.assertEqual(b_val, res[1][1])
|
self.assertEqual(b_val, res[1][1])
|
||||||
self.assertEqual(c_val, res[1][2])
|
self.assertEqual(c_val, res[1][2])
|
||||||
self.assertTrue(isinstance(res[2], ABC))
|
self.assertIsInstance(res[2], ABC)
|
||||||
self.assertEqual(a_val, res[2].a)
|
self.assertEqual(a_val, res[2].a)
|
||||||
self.assertEqual(b_val, res[2].b)
|
self.assertEqual(b_val, res[2].b)
|
||||||
self.assertEqual(c_val, res[2].c)
|
self.assertEqual(c_val, res[2].c)
|
||||||
self.assertTrue(isinstance(res[3], dict))
|
for expected_type, r in zip(test_dct_types, res[3:]):
|
||||||
self.assertEqual(3, len(res[3]))
|
self.assertIsInstance(r, expected_type)
|
||||||
self.assertEqual(a_val, res[3]['a'])
|
self.assertEqual(3, len(r))
|
||||||
self.assertEqual(b_val, res[3]['b'])
|
self.assertEqual(a_val, r['a'])
|
||||||
self.assertEqual(c_val, res[3]['c'])
|
self.assertEqual(b_val, r['b'])
|
||||||
# Tuple of lists, tuples, namedtuple, and dict
|
self.assertEqual(c_val, r['c'])
|
||||||
res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), {
|
self.assertEqual(res[5].default_factory, str)
|
||||||
'a': a,
|
# Tuple of lists, tuples, namedtuple, dict, frozendict, and defaultdict
|
||||||
'c': c,
|
res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b,
|
||||||
'b': b
|
c=c), dict(test_dct),
|
||||||
}))
|
frozendict(test_dct), defaultdict(str, test_dct)))
|
||||||
self.assertTrue(isinstance(res, tuple))
|
self.assertIsInstance(res, tuple)
|
||||||
self.assertEqual(4, len(res))
|
self.assertEqual(6, len(res))
|
||||||
self.assertTrue(isinstance(res[0], list))
|
self.assertIsInstance(res[0], list)
|
||||||
self.assertEqual(3, len(res[0]))
|
self.assertEqual(3, len(res[0]))
|
||||||
self.assertEqual(a_val, res[0][0])
|
self.assertEqual(a_val, res[0][0])
|
||||||
self.assertEqual(b_val, res[0][1])
|
self.assertEqual(b_val, res[0][1])
|
||||||
self.assertEqual(c_val, res[0][2])
|
self.assertEqual(c_val, res[0][2])
|
||||||
self.assertTrue(isinstance(res[1], tuple))
|
self.assertIsInstance(res[1], tuple)
|
||||||
self.assertEqual(3, len(res[1]))
|
self.assertEqual(3, len(res[1]))
|
||||||
self.assertEqual(a_val, res[1][0])
|
self.assertEqual(a_val, res[1][0])
|
||||||
self.assertEqual(b_val, res[1][1])
|
self.assertEqual(b_val, res[1][1])
|
||||||
self.assertEqual(c_val, res[1][2])
|
self.assertEqual(c_val, res[1][2])
|
||||||
self.assertTrue(isinstance(res[2], ABC))
|
self.assertIsInstance(res[2], ABC)
|
||||||
self.assertEqual(a_val, res[2].a)
|
self.assertEqual(a_val, res[2].a)
|
||||||
self.assertEqual(b_val, res[2].b)
|
self.assertEqual(b_val, res[2].b)
|
||||||
self.assertEqual(c_val, res[2].c)
|
self.assertEqual(c_val, res[2].c)
|
||||||
self.assertTrue(isinstance(res[3], dict))
|
for expected_type, r in zip(test_dct_types, res[3:]):
|
||||||
self.assertEqual(3, len(res[3]))
|
self.assertIsInstance(r, expected_type)
|
||||||
self.assertEqual(a_val, res[3]['a'])
|
self.assertEqual(3, len(r))
|
||||||
self.assertEqual(b_val, res[3]['b'])
|
self.assertEqual(a_val, r['a'])
|
||||||
self.assertEqual(c_val, res[3]['c'])
|
self.assertEqual(b_val, r['b'])
|
||||||
# Namedtuple of lists, tuples, namedtuples, and dict
|
self.assertEqual(c_val, r['c'])
|
||||||
|
self.assertEqual(res[5].default_factory, str)
|
||||||
|
|
||||||
|
# Namedtuple of lists, tuples, namedtuples, dict, frozendict, defaultdict
|
||||||
res = sess.run(
|
res = sess.run(
|
||||||
DEFG(
|
DEFGHI(
|
||||||
d=[a, b, c],
|
d=[a, b, c],
|
||||||
e=(a, b, c),
|
e=(a, b, c),
|
||||||
f=ABC(a=a.name, b=b, c=c),
|
f=ABC(a=a.name, b=b, c=c),
|
||||||
g={
|
g=dict(test_dct),
|
||||||
'a': a,
|
h=frozendict(test_dct),
|
||||||
'c': c,
|
i=defaultdict(str, test_dct)))
|
||||||
'b': b
|
self.assertIsInstance(res, DEFGHI)
|
||||||
}))
|
self.assertIsInstance(res.d, list)
|
||||||
self.assertTrue(isinstance(res, DEFG))
|
|
||||||
self.assertTrue(isinstance(res.d, list))
|
|
||||||
self.assertEqual(3, len(res.d))
|
self.assertEqual(3, len(res.d))
|
||||||
self.assertEqual(a_val, res.d[0])
|
self.assertEqual(a_val, res.d[0])
|
||||||
self.assertEqual(b_val, res.d[1])
|
self.assertEqual(b_val, res.d[1])
|
||||||
self.assertEqual(c_val, res.d[2])
|
self.assertEqual(c_val, res.d[2])
|
||||||
self.assertTrue(isinstance(res.e, tuple))
|
self.assertIsInstance(res.e, tuple)
|
||||||
self.assertEqual(3, len(res.e))
|
self.assertEqual(3, len(res.e))
|
||||||
self.assertEqual(a_val, res.e[0])
|
self.assertEqual(a_val, res.e[0])
|
||||||
self.assertEqual(b_val, res.e[1])
|
self.assertEqual(b_val, res.e[1])
|
||||||
self.assertEqual(c_val, res.e[2])
|
self.assertEqual(c_val, res.e[2])
|
||||||
self.assertTrue(isinstance(res.f, ABC))
|
self.assertIsInstance(res.f, ABC)
|
||||||
self.assertEqual(a_val, res.f.a)
|
self.assertEqual(a_val, res.f.a)
|
||||||
self.assertEqual(b_val, res.f.b)
|
self.assertEqual(b_val, res.f.b)
|
||||||
self.assertEqual(c_val, res.f.c)
|
self.assertEqual(c_val, res.f.c)
|
||||||
self.assertTrue(isinstance(res.g, dict))
|
self.assertIsInstance(res.g, dict)
|
||||||
self.assertEqual(3, len(res.g))
|
self.assertEqual(3, len(res.g))
|
||||||
self.assertEqual(a_val, res.g['a'])
|
self.assertEqual(a_val, res.g['a'])
|
||||||
self.assertEqual(b_val, res.g['b'])
|
self.assertEqual(b_val, res.g['b'])
|
||||||
self.assertEqual(c_val, res.g['c'])
|
self.assertEqual(c_val, res.g['c'])
|
||||||
# Dict of lists, tuples, namedtuples, and dict
|
self.assertIsInstance(res.h, frozendict)
|
||||||
|
self.assertEqual(3, len(res.h))
|
||||||
|
self.assertEqual(a_val, res.h['a'])
|
||||||
|
self.assertEqual(b_val, res.h['b'])
|
||||||
|
self.assertEqual(c_val, res.h['c'])
|
||||||
|
self.assertIsInstance(res.i, defaultdict)
|
||||||
|
self.assertEqual(3, len(res.i))
|
||||||
|
self.assertEqual(a_val, res.i['a'])
|
||||||
|
self.assertEqual(b_val, res.i['b'])
|
||||||
|
self.assertEqual(c_val, res.i['c'])
|
||||||
|
self.assertEqual(res.i.default_factory, str)
|
||||||
|
# Dict of lists, tuples, namedtuples, dict, frozendict, defaultdict
|
||||||
res = sess.run({
|
res = sess.run({
|
||||||
'd': [a, b, c],
|
'd': [a, b, c],
|
||||||
'e': (a, b, c),
|
'e': (a, b, c),
|
||||||
'f': ABC(a=a, b=b, c=c),
|
'f': ABC(a=a, b=b, c=c),
|
||||||
'g': {
|
'g': dict(test_dct),
|
||||||
'a': a.name,
|
'h': frozendict(test_dct),
|
||||||
'c': c,
|
'i': defaultdict(str, test_dct),
|
||||||
'b': b
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
self.assertTrue(isinstance(res, dict))
|
self.assertIsInstance(res, dict)
|
||||||
self.assertEqual(4, len(res))
|
self.assertEqual(6, len(res))
|
||||||
self.assertTrue(isinstance(res['d'], list))
|
self.assertIsInstance(res['d'], list)
|
||||||
self.assertEqual(3, len(res['d']))
|
self.assertEqual(3, len(res['d']))
|
||||||
self.assertEqual(a_val, res['d'][0])
|
self.assertEqual(a_val, res['d'][0])
|
||||||
self.assertEqual(b_val, res['d'][1])
|
self.assertEqual(b_val, res['d'][1])
|
||||||
self.assertEqual(c_val, res['d'][2])
|
self.assertEqual(c_val, res['d'][2])
|
||||||
self.assertTrue(isinstance(res['e'], tuple))
|
self.assertIsInstance(res['e'], tuple)
|
||||||
self.assertEqual(3, len(res['e']))
|
self.assertEqual(3, len(res['e']))
|
||||||
self.assertEqual(a_val, res['e'][0])
|
self.assertEqual(a_val, res['e'][0])
|
||||||
self.assertEqual(b_val, res['e'][1])
|
self.assertEqual(b_val, res['e'][1])
|
||||||
self.assertEqual(c_val, res['e'][2])
|
self.assertEqual(c_val, res['e'][2])
|
||||||
self.assertTrue(isinstance(res['f'], ABC))
|
self.assertIsInstance(res['f'], ABC)
|
||||||
self.assertEqual(a_val, res['f'].a)
|
self.assertEqual(a_val, res['f'].a)
|
||||||
self.assertEqual(b_val, res['f'].b)
|
self.assertEqual(b_val, res['f'].b)
|
||||||
self.assertEqual(c_val, res['f'].c)
|
self.assertEqual(c_val, res['f'].c)
|
||||||
self.assertTrue(isinstance(res['g'], dict))
|
for expected_type, r_key in zip(test_dct_types, ('g', 'h', 'i')):
|
||||||
self.assertEqual(3, len(res['g']))
|
r = res[r_key]
|
||||||
self.assertEqual(a_val, res['g']['a'])
|
self.assertIsInstance(r, expected_type)
|
||||||
self.assertEqual(b_val, res['g']['b'])
|
self.assertEqual(3, len(r))
|
||||||
self.assertEqual(c_val, res['g']['c'])
|
self.assertEqual(a_val, r['a'])
|
||||||
|
self.assertEqual(b_val, r['b'])
|
||||||
|
self.assertEqual(c_val, r['c'])
|
||||||
|
self.assertEqual(res['i'].default_factory, str)
|
||||||
|
|
||||||
def testFetchTensorObject(self):
|
def testFetchTensorObject(self):
|
||||||
with session.Session() as s:
|
with session.Session() as s:
|
||||||
@ -1279,7 +1303,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
@test_util.run_v1_only('b/120545219')
|
@test_util.run_v1_only('b/120545219')
|
||||||
def testNotEntered(self):
|
def testNotEntered(self):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
self.assertEqual(ops._default_session_stack.get_default(), None)
|
self.assertIsNone(ops._default_session_stack.get_default())
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
with ops.device('/cpu:0'):
|
with ops.device('/cpu:0'):
|
||||||
sess = session.Session()
|
sess = session.Session()
|
||||||
@ -1326,10 +1350,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
sess2 = session.InteractiveSession()
|
sess2 = session.InteractiveSession()
|
||||||
self.assertEqual(1, len(w))
|
self.assertEqual(1, len(w))
|
||||||
self.assertTrue('An interactive session is already active. This can cause '
|
self.assertIn('An interactive session is already active. This can cause '
|
||||||
'out-of-memory errors in some cases. You must explicitly '
|
'out-of-memory errors in some cases. You must explicitly '
|
||||||
'call `InteractiveSession.close()` to release resources '
|
'call `InteractiveSession.close()` to release resources '
|
||||||
'held by the other session(s).' in str(w[0].message))
|
'held by the other session(s).', str(w[0].message))
|
||||||
sess2.close()
|
sess2.close()
|
||||||
sess.close()
|
sess.close()
|
||||||
|
|
||||||
@ -1610,10 +1634,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
e = constant_op.constant(44.0, name=b'e')
|
e = constant_op.constant(44.0, name=b'e')
|
||||||
f = constant_op.constant(45.0, name=r'f')
|
f = constant_op.constant(45.0, name=r'f')
|
||||||
|
|
||||||
self.assertTrue(isinstance(c.name, six.text_type))
|
self.assertIsInstance(c.name, six.text_type)
|
||||||
self.assertTrue(isinstance(d.name, six.text_type))
|
self.assertIsInstance(d.name, six.text_type)
|
||||||
self.assertTrue(isinstance(e.name, six.text_type))
|
self.assertIsInstance(e.name, six.text_type)
|
||||||
self.assertTrue(isinstance(f.name, six.text_type))
|
self.assertIsInstance(f.name, six.text_type)
|
||||||
|
|
||||||
self.assertEqual(42.0, sess.run('c:0'))
|
self.assertEqual(42.0, sess.run('c:0'))
|
||||||
self.assertEqual(42.0, sess.run(u'c:0'))
|
self.assertEqual(42.0, sess.run(u'c:0'))
|
||||||
@ -1673,10 +1697,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.device('/cpu:0'):
|
with ops.device('/cpu:0'):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
sess.run(constant_op.constant(1.0))
|
sess.run(constant_op.constant(1.0))
|
||||||
self.assertTrue(not run_metadata.HasField('step_stats'))
|
self.assertFalse(run_metadata.HasField('step_stats'))
|
||||||
|
|
||||||
sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
|
sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
|
||||||
self.assertTrue(not run_metadata.HasField('step_stats'))
|
self.assertFalse(run_metadata.HasField('step_stats'))
|
||||||
|
|
||||||
sess.run(
|
sess.run(
|
||||||
constant_op.constant(1.0),
|
constant_op.constant(1.0),
|
||||||
@ -1697,11 +1721,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
|
sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
|
||||||
sess.run(
|
sess.run(
|
||||||
constant_op.constant(1.0), options=None, run_metadata=run_metadata)
|
constant_op.constant(1.0), options=None, run_metadata=run_metadata)
|
||||||
self.assertTrue(not run_metadata.HasField('step_stats'))
|
self.assertFalse(run_metadata.HasField('step_stats'))
|
||||||
|
|
||||||
sess.run(
|
sess.run(
|
||||||
constant_op.constant(1.0), options=run_options, run_metadata=None)
|
constant_op.constant(1.0), options=run_options, run_metadata=None)
|
||||||
self.assertTrue(not run_metadata.HasField('step_stats'))
|
self.assertFalse(run_metadata.HasField('step_stats'))
|
||||||
|
|
||||||
sess.run(
|
sess.run(
|
||||||
constant_op.constant(1.0),
|
constant_op.constant(1.0),
|
||||||
@ -1730,9 +1754,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.Graph().as_default(), ops.device('/cpu:0'):
|
with ops.Graph().as_default(), ops.device('/cpu:0'):
|
||||||
a = constant_op.constant([[1, 2]])
|
a = constant_op.constant([[1, 2]])
|
||||||
sess = session.Session()
|
sess = session.Session()
|
||||||
self.assertFalse('_output_shapes' in sess.graph_def.node[0].attr)
|
self.assertNotIn('_output_shapes', sess.graph_def.node[0].attr)
|
||||||
# Avoid lint error regarding 'unused' var a.
|
# Avoid lint error regarding 'unused' var a.
|
||||||
self.assertTrue(a == a)
|
self.assertEqual(a, a)
|
||||||
|
|
||||||
def testInferShapesTrue(self):
|
def testInferShapesTrue(self):
|
||||||
config_pb = config_pb2.ConfigProto(
|
config_pb = config_pb2.ConfigProto(
|
||||||
@ -1740,9 +1764,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
with ops.Graph().as_default(), ops.device('/cpu:0'):
|
with ops.Graph().as_default(), ops.device('/cpu:0'):
|
||||||
a = constant_op.constant([[1, 2]])
|
a = constant_op.constant([[1, 2]])
|
||||||
sess = session.Session(config=config_pb)
|
sess = session.Session(config=config_pb)
|
||||||
self.assertTrue('_output_shapes' in sess.graph_def.node[0].attr)
|
self.assertIn('_output_shapes', sess.graph_def.node[0].attr)
|
||||||
# Avoid lint error regarding 'unused' var a.
|
# Avoid lint error regarding 'unused' var a.
|
||||||
self.assertTrue(a == a)
|
self.assertEqual(a, a)
|
||||||
|
|
||||||
def testBuildCostModel(self):
|
def testBuildCostModel(self):
|
||||||
run_options = config_pb2.RunOptions()
|
run_options = config_pb2.RunOptions()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user