Pass all the arguments at once in the generator to fetch_type() in _DictFetchMapper and ensure the same default_factory for collections.defaultdict type.

PiperOrigin-RevId: 330000469
Change-Id: Id09a5f79528164b18f5b78a31ee1e13c3ca1518d
This commit is contained in:
A. Unique TensorFlower 2020-09-03 15:01:28 -07:00 committed by TensorFlower Gardener
parent 2cea1ba130
commit 8320d33533
2 changed files with 142 additions and 108 deletions

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import re
import threading
@ -41,13 +42,14 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export
_python_session_create_counter = monitoring.Counter(
'/tensorflow/api/python/session_create_counter',
'Counter for number of sessions created in Python.')
class SessionInterface(object):
"""Base class for implementations of TensorFlow client sessions."""
@ -406,6 +408,12 @@ class _DictFetchMapper(_FetchMapper):
fetches: Dict of fetches.
"""
self._fetch_type = type(fetches)
if isinstance(fetches, collections.defaultdict):
self._type_ctor = functools.partial(collections.defaultdict,
fetches.default_factory)
else:
self._type_ctor = self._fetch_type
self._keys = fetches.keys()
self._mappers = [
_FetchMapper.for_fetch(fetch) for fetch in fetches.values()
@ -416,10 +424,12 @@ class _DictFetchMapper(_FetchMapper):
return self._unique_fetches
def build_results(self, values):
results = self._fetch_type()
for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
results[k] = m.build_results([values[j] for j in vi])
return results
def _generator():
for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
yield k, m.build_results([values[j] for j in vi])
return self._type_ctor(_generator())
class _AttrsFetchMapper(_FetchMapper):

View File

@ -18,8 +18,8 @@ from __future__ import division
from __future__ import print_function
import collections
import random
import os
import random
import sys
import threading
import time
@ -68,6 +68,13 @@ try:
except ImportError:
attr = None
try:
from frozendict import frozendict # pylint:disable=g-import-not-at-top
except ImportError:
frozendict = dict # pylint:disable=invalid-name
defaultdict = collections.defaultdict # pylint:disable=invalid-name
class SessionTest(test_util.TensorFlowTestCase):
@ -222,13 +229,13 @@ class SessionTest(test_util.TensorFlowTestCase):
res = sess.run(a)
self.assertEqual(42.0, res)
res = sess.run(a.op) # An op, not a tensor.
self.assertEqual(None, res)
self.assertIsNone(res)
tensor_runner = sess.make_callable(a)
res = tensor_runner()
self.assertEqual(42.0, res)
op_runner = sess.make_callable(a.op)
res = op_runner()
self.assertEqual(None, res)
self.assertIsNone(res)
def testFetchSingletonByName(self):
with session.Session() as sess:
@ -236,7 +243,7 @@ class SessionTest(test_util.TensorFlowTestCase):
res = sess.run(a.name)
self.assertEqual(42.0, res)
res = sess.run(a.op) # An op, not a tensor.
self.assertEqual(None, res)
self.assertIsNone(res)
def testFetchList(self):
with session.Session() as sess:
@ -246,11 +253,11 @@ class SessionTest(test_util.TensorFlowTestCase):
v = variables.Variable([54.0])
assign = v.assign([63.0])
res = sess.run([a, b, c, a.name, assign.op])
self.assertTrue(isinstance(res, list))
self.assertIsInstance(res, list)
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
list_runner = sess.make_callable([a, b, c, a.name, assign.op])
res = list_runner()
self.assertTrue(isinstance(res, list))
self.assertIsInstance(res, list)
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
def testFetchTuple(self):
@ -259,11 +266,11 @@ class SessionTest(test_util.TensorFlowTestCase):
b = control_flow_ops.no_op() # An op, not a tensor.
c = constant_op.constant(44.0)
res = sess.run((a, b, c, a.name))
self.assertTrue(isinstance(res, tuple))
self.assertIsInstance(res, tuple)
self.assertEqual((42.0, None, 44.0, 42.0), res)
tuple_runner = sess.make_callable((a, b, c, a.name))
res = tuple_runner()
self.assertTrue(isinstance(res, tuple))
self.assertIsInstance(res, tuple)
self.assertEqual((42.0, None, 44.0, 42.0), res)
def testFetchNamedTuple(self):
@ -275,15 +282,15 @@ class SessionTest(test_util.TensorFlowTestCase):
b = control_flow_ops.no_op() # An op, not a tensor.
c = constant_op.constant(44.0)
res = sess.run(ABC(a, b, c))
self.assertTrue(isinstance(res, ABC))
self.assertIsInstance(res, ABC)
self.assertEqual(42.0, res.a)
self.assertEqual(None, res.b)
self.assertIsNone(res.b)
self.assertEqual(44.0, res.c)
namedtuple_runner = sess.make_callable(ABC(a, b, c))
res = namedtuple_runner()
self.assertTrue(isinstance(res, ABC))
self.assertIsInstance(res, ABC)
self.assertEqual(42.0, res.a)
self.assertEqual(None, res.b)
self.assertIsNone(res.b)
self.assertEqual(44.0, res.c)
def testFetchDict(self):
@ -292,9 +299,9 @@ class SessionTest(test_util.TensorFlowTestCase):
b = control_flow_ops.no_op() # An op, not a tensor.
c = constant_op.constant(44.0)
res = sess.run({'a': a, 'b': b, 'c': c})
self.assertTrue(isinstance(res, dict))
self.assertIsInstance(res, dict)
self.assertEqual(42.0, res['a'])
self.assertEqual(None, res['b'])
self.assertIsNone(res['b'])
self.assertEqual(44.0, res['c'])
def testFetchOrderedDict(self):
@ -303,10 +310,10 @@ class SessionTest(test_util.TensorFlowTestCase):
b = control_flow_ops.no_op() # An op, not a tensor.
c = constant_op.constant(44.0)
res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
self.assertTrue(isinstance(res, collections.OrderedDict))
self.assertIsInstance(res, collections.OrderedDict)
self.assertEqual([3, 2, 1], list(res.keys()))
self.assertEqual(42.0, res[3])
self.assertEqual(None, res[2])
self.assertIsNone(res[2])
self.assertEqual(44.0, res[1])
@test_util.run_v1_only('b/120545219')
@ -393,23 +400,23 @@ class SessionTest(test_util.TensorFlowTestCase):
a = constant_op.constant(a_val)
res = sess.run([[], tuple(), {}])
self.assertTrue(isinstance(res, list))
self.assertIsInstance(res, list)
self.assertEqual(3, len(res))
self.assertTrue(isinstance(res[0], list))
self.assertIsInstance(res[0], list)
self.assertEqual(0, len(res[0]))
self.assertTrue(isinstance(res[1], tuple))
self.assertIsInstance(res[1], tuple)
self.assertEqual(0, len(res[1]))
self.assertTrue(isinstance(res[2], dict))
self.assertIsInstance(res[2], dict)
self.assertEqual(0, len(res[2]))
res = sess.run([[], tuple(), {}, a])
self.assertTrue(isinstance(res, list))
self.assertIsInstance(res, list)
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res[0], list))
self.assertIsInstance(res[0], list)
self.assertEqual(0, len(res[0]))
self.assertTrue(isinstance(res[1], tuple))
self.assertIsInstance(res[1], tuple)
self.assertEqual(0, len(res[1]))
self.assertTrue(isinstance(res[2], dict))
self.assertIsInstance(res[2], dict)
self.assertEqual(0, len(res[2]))
self.assertEqual(a_val, res[3])
@ -417,7 +424,7 @@ class SessionTest(test_util.TensorFlowTestCase):
with session.Session() as sess:
# pylint: disable=invalid-name
ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
DEFG = collections.namedtuple('DEFG', ['d', 'e', 'f', 'g'])
DEFGHI = collections.namedtuple('DEFGHI', ['d', 'e', 'f', 'g', 'h', 'i'])
# pylint: enable=invalid-name
a_val = 42.0
b_val = None
@ -425,124 +432,141 @@ class SessionTest(test_util.TensorFlowTestCase):
a = constant_op.constant(a_val)
b = control_flow_ops.no_op() # An op, not a tensor.
c = constant_op.constant(c_val)
# List of lists, tuples, namedtuple, and dict
res = sess.run([[a, b, c], (a, b, c),
ABC(a=a, b=b, c=c), {
'a': a.name,
'c': c,
'b': b
}])
self.assertTrue(isinstance(res, list))
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res[0], list))
test_dct = {'a': a.name, 'c': c, 'b': b}
test_dct_types = [dict, frozendict, defaultdict]
# List of lists, tuples, namedtuple, dict, frozendict, and defaultdict
res = sess.run([
[a, b, c],
(a, b, c),
ABC(a=a, b=b, c=c),
dict(test_dct),
frozendict(test_dct),
defaultdict(str, test_dct),
])
self.assertIsInstance(res, list)
self.assertEqual(6, len(res))
self.assertIsInstance(res[0], list)
self.assertEqual(3, len(res[0]))
self.assertEqual(a_val, res[0][0])
self.assertEqual(b_val, res[0][1])
self.assertEqual(c_val, res[0][2])
self.assertTrue(isinstance(res[1], tuple))
self.assertIsInstance(res[1], tuple)
self.assertEqual(3, len(res[1]))
self.assertEqual(a_val, res[1][0])
self.assertEqual(b_val, res[1][1])
self.assertEqual(c_val, res[1][2])
self.assertTrue(isinstance(res[2], ABC))
self.assertIsInstance(res[2], ABC)
self.assertEqual(a_val, res[2].a)
self.assertEqual(b_val, res[2].b)
self.assertEqual(c_val, res[2].c)
self.assertTrue(isinstance(res[3], dict))
self.assertEqual(3, len(res[3]))
self.assertEqual(a_val, res[3]['a'])
self.assertEqual(b_val, res[3]['b'])
self.assertEqual(c_val, res[3]['c'])
# Tuple of lists, tuples, namedtuple, and dict
res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), {
'a': a,
'c': c,
'b': b
}))
self.assertTrue(isinstance(res, tuple))
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res[0], list))
for expected_type, r in zip(test_dct_types, res[3:]):
self.assertIsInstance(r, expected_type)
self.assertEqual(3, len(r))
self.assertEqual(a_val, r['a'])
self.assertEqual(b_val, r['b'])
self.assertEqual(c_val, r['c'])
self.assertEqual(res[5].default_factory, str)
# Tuple of lists, tuples, namedtuple, dict, frozendict, and defaultdict
res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b,
c=c), dict(test_dct),
frozendict(test_dct), defaultdict(str, test_dct)))
self.assertIsInstance(res, tuple)
self.assertEqual(6, len(res))
self.assertIsInstance(res[0], list)
self.assertEqual(3, len(res[0]))
self.assertEqual(a_val, res[0][0])
self.assertEqual(b_val, res[0][1])
self.assertEqual(c_val, res[0][2])
self.assertTrue(isinstance(res[1], tuple))
self.assertIsInstance(res[1], tuple)
self.assertEqual(3, len(res[1]))
self.assertEqual(a_val, res[1][0])
self.assertEqual(b_val, res[1][1])
self.assertEqual(c_val, res[1][2])
self.assertTrue(isinstance(res[2], ABC))
self.assertIsInstance(res[2], ABC)
self.assertEqual(a_val, res[2].a)
self.assertEqual(b_val, res[2].b)
self.assertEqual(c_val, res[2].c)
self.assertTrue(isinstance(res[3], dict))
self.assertEqual(3, len(res[3]))
self.assertEqual(a_val, res[3]['a'])
self.assertEqual(b_val, res[3]['b'])
self.assertEqual(c_val, res[3]['c'])
# Namedtuple of lists, tuples, namedtuples, and dict
for expected_type, r in zip(test_dct_types, res[3:]):
self.assertIsInstance(r, expected_type)
self.assertEqual(3, len(r))
self.assertEqual(a_val, r['a'])
self.assertEqual(b_val, r['b'])
self.assertEqual(c_val, r['c'])
self.assertEqual(res[5].default_factory, str)
# Namedtuple of lists, tuples, namedtuples, dict, frozendict, defaultdict
res = sess.run(
DEFG(
DEFGHI(
d=[a, b, c],
e=(a, b, c),
f=ABC(a=a.name, b=b, c=c),
g={
'a': a,
'c': c,
'b': b
}))
self.assertTrue(isinstance(res, DEFG))
self.assertTrue(isinstance(res.d, list))
g=dict(test_dct),
h=frozendict(test_dct),
i=defaultdict(str, test_dct)))
self.assertIsInstance(res, DEFGHI)
self.assertIsInstance(res.d, list)
self.assertEqual(3, len(res.d))
self.assertEqual(a_val, res.d[0])
self.assertEqual(b_val, res.d[1])
self.assertEqual(c_val, res.d[2])
self.assertTrue(isinstance(res.e, tuple))
self.assertIsInstance(res.e, tuple)
self.assertEqual(3, len(res.e))
self.assertEqual(a_val, res.e[0])
self.assertEqual(b_val, res.e[1])
self.assertEqual(c_val, res.e[2])
self.assertTrue(isinstance(res.f, ABC))
self.assertIsInstance(res.f, ABC)
self.assertEqual(a_val, res.f.a)
self.assertEqual(b_val, res.f.b)
self.assertEqual(c_val, res.f.c)
self.assertTrue(isinstance(res.g, dict))
self.assertIsInstance(res.g, dict)
self.assertEqual(3, len(res.g))
self.assertEqual(a_val, res.g['a'])
self.assertEqual(b_val, res.g['b'])
self.assertEqual(c_val, res.g['c'])
# Dict of lists, tuples, namedtuples, and dict
self.assertIsInstance(res.h, frozendict)
self.assertEqual(3, len(res.h))
self.assertEqual(a_val, res.h['a'])
self.assertEqual(b_val, res.h['b'])
self.assertEqual(c_val, res.h['c'])
self.assertIsInstance(res.i, defaultdict)
self.assertEqual(3, len(res.i))
self.assertEqual(a_val, res.i['a'])
self.assertEqual(b_val, res.i['b'])
self.assertEqual(c_val, res.i['c'])
self.assertEqual(res.i.default_factory, str)
# Dict of lists, tuples, namedtuples, dict, frozendict, defaultdict
res = sess.run({
'd': [a, b, c],
'e': (a, b, c),
'f': ABC(a=a, b=b, c=c),
'g': {
'a': a.name,
'c': c,
'b': b
}
'g': dict(test_dct),
'h': frozendict(test_dct),
'i': defaultdict(str, test_dct),
})
self.assertTrue(isinstance(res, dict))
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res['d'], list))
self.assertIsInstance(res, dict)
self.assertEqual(6, len(res))
self.assertIsInstance(res['d'], list)
self.assertEqual(3, len(res['d']))
self.assertEqual(a_val, res['d'][0])
self.assertEqual(b_val, res['d'][1])
self.assertEqual(c_val, res['d'][2])
self.assertTrue(isinstance(res['e'], tuple))
self.assertIsInstance(res['e'], tuple)
self.assertEqual(3, len(res['e']))
self.assertEqual(a_val, res['e'][0])
self.assertEqual(b_val, res['e'][1])
self.assertEqual(c_val, res['e'][2])
self.assertTrue(isinstance(res['f'], ABC))
self.assertIsInstance(res['f'], ABC)
self.assertEqual(a_val, res['f'].a)
self.assertEqual(b_val, res['f'].b)
self.assertEqual(c_val, res['f'].c)
self.assertTrue(isinstance(res['g'], dict))
self.assertEqual(3, len(res['g']))
self.assertEqual(a_val, res['g']['a'])
self.assertEqual(b_val, res['g']['b'])
self.assertEqual(c_val, res['g']['c'])
for expected_type, r_key in zip(test_dct_types, ('g', 'h', 'i')):
r = res[r_key]
self.assertIsInstance(r, expected_type)
self.assertEqual(3, len(r))
self.assertEqual(a_val, r['a'])
self.assertEqual(b_val, r['b'])
self.assertEqual(c_val, r['c'])
self.assertEqual(res['i'].default_factory, str)
def testFetchTensorObject(self):
with session.Session() as s:
@ -1279,7 +1303,7 @@ class SessionTest(test_util.TensorFlowTestCase):
@test_util.run_v1_only('b/120545219')
def testNotEntered(self):
# pylint: disable=protected-access
self.assertEqual(ops._default_session_stack.get_default(), None)
self.assertIsNone(ops._default_session_stack.get_default())
# pylint: enable=protected-access
with ops.device('/cpu:0'):
sess = session.Session()
@ -1326,10 +1350,10 @@ class SessionTest(test_util.TensorFlowTestCase):
with warnings.catch_warnings(record=True) as w:
sess2 = session.InteractiveSession()
self.assertEqual(1, len(w))
self.assertTrue('An interactive session is already active. This can cause '
'out-of-memory errors in some cases. You must explicitly '
'call `InteractiveSession.close()` to release resources '
'held by the other session(s).' in str(w[0].message))
self.assertIn('An interactive session is already active. This can cause '
'out-of-memory errors in some cases. You must explicitly '
'call `InteractiveSession.close()` to release resources '
'held by the other session(s).', str(w[0].message))
sess2.close()
sess.close()
@ -1610,10 +1634,10 @@ class SessionTest(test_util.TensorFlowTestCase):
e = constant_op.constant(44.0, name=b'e')
f = constant_op.constant(45.0, name=r'f')
self.assertTrue(isinstance(c.name, six.text_type))
self.assertTrue(isinstance(d.name, six.text_type))
self.assertTrue(isinstance(e.name, six.text_type))
self.assertTrue(isinstance(f.name, six.text_type))
self.assertIsInstance(c.name, six.text_type)
self.assertIsInstance(d.name, six.text_type)
self.assertIsInstance(e.name, six.text_type)
self.assertIsInstance(f.name, six.text_type)
self.assertEqual(42.0, sess.run('c:0'))
self.assertEqual(42.0, sess.run(u'c:0'))
@ -1673,10 +1697,10 @@ class SessionTest(test_util.TensorFlowTestCase):
with ops.device('/cpu:0'):
with session.Session() as sess:
sess.run(constant_op.constant(1.0))
self.assertTrue(not run_metadata.HasField('step_stats'))
self.assertFalse(run_metadata.HasField('step_stats'))
sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
self.assertTrue(not run_metadata.HasField('step_stats'))
self.assertFalse(run_metadata.HasField('step_stats'))
sess.run(
constant_op.constant(1.0),
@ -1697,11 +1721,11 @@ class SessionTest(test_util.TensorFlowTestCase):
sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
sess.run(
constant_op.constant(1.0), options=None, run_metadata=run_metadata)
self.assertTrue(not run_metadata.HasField('step_stats'))
self.assertFalse(run_metadata.HasField('step_stats'))
sess.run(
constant_op.constant(1.0), options=run_options, run_metadata=None)
self.assertTrue(not run_metadata.HasField('step_stats'))
self.assertFalse(run_metadata.HasField('step_stats'))
sess.run(
constant_op.constant(1.0),
@ -1730,9 +1754,9 @@ class SessionTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default(), ops.device('/cpu:0'):
a = constant_op.constant([[1, 2]])
sess = session.Session()
self.assertFalse('_output_shapes' in sess.graph_def.node[0].attr)
self.assertNotIn('_output_shapes', sess.graph_def.node[0].attr)
# Avoid lint error regarding 'unused' var a.
self.assertTrue(a == a)
self.assertEqual(a, a)
def testInferShapesTrue(self):
config_pb = config_pb2.ConfigProto(
@ -1740,9 +1764,9 @@ class SessionTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default(), ops.device('/cpu:0'):
a = constant_op.constant([[1, 2]])
sess = session.Session(config=config_pb)
self.assertTrue('_output_shapes' in sess.graph_def.node[0].attr)
self.assertIn('_output_shapes', sess.graph_def.node[0].attr)
# Avoid lint error regarding 'unused' var a.
self.assertTrue(a == a)
self.assertEqual(a, a)
def testBuildCostModel(self):
run_options = config_pb2.RunOptions()