From 8320d33533425416b1fa50c4325eb240de0f00b7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 Sep 2020 15:01:28 -0700 Subject: [PATCH] Pass all the arguments at once in the generator to fetch_type() in _DictFetchMapper and ensure the same default_factory for collections.defaultdict type. PiperOrigin-RevId: 330000469 Change-Id: Id09a5f79528164b18f5b78a31ee1e13c3ca1518d --- tensorflow/python/client/session.py | 20 +- tensorflow/python/client/session_test.py | 230 +++++++++++++---------- 2 files changed, 142 insertions(+), 108 deletions(-) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index bcd27fb6318..5a83f5776a1 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import functools import re import threading @@ -41,13 +42,14 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.experimental import mixed_precision_global_state from tensorflow.python.util import compat from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.tf_export import tf_export _python_session_create_counter = monitoring.Counter( '/tensorflow/api/python/session_create_counter', 'Counter for number of sessions created in Python.') + class SessionInterface(object): """Base class for implementations of TensorFlow client sessions.""" @@ -406,6 +408,12 @@ class _DictFetchMapper(_FetchMapper): fetches: Dict of fetches. """ self._fetch_type = type(fetches) + if isinstance(fetches, collections.defaultdict): + self._type_ctor = functools.partial(collections.defaultdict, + fetches.default_factory) + else: + self._type_ctor = self._fetch_type + self._keys = fetches.keys() self._mappers = [ _FetchMapper.for_fetch(fetch) for fetch in fetches.values() @@ -416,10 +424,12 @@ class _DictFetchMapper(_FetchMapper): return self._unique_fetches def build_results(self, values): - results = self._fetch_type() - for k, m, vi in zip(self._keys, self._mappers, self._value_indices): - results[k] = m.build_results([values[j] for j in vi]) - return results + + def _generator(): + for k, m, vi in zip(self._keys, self._mappers, self._value_indices): + yield k, m.build_results([values[j] for j in vi]) + + return self._type_ctor(_generator()) class _AttrsFetchMapper(_FetchMapper): diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 23d5ddaee44..4bf5095ae8b 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function import collections -import random import os +import random import sys import threading import time @@ -68,6 +68,13 @@ try: except ImportError: attr = None +try: + from frozendict import frozendict # pylint:disable=g-import-not-at-top +except ImportError: + frozendict = dict # pylint:disable=invalid-name + +defaultdict = collections.defaultdict # pylint:disable=invalid-name + class SessionTest(test_util.TensorFlowTestCase): @@ -222,13 +229,13 @@ class SessionTest(test_util.TensorFlowTestCase): res = sess.run(a) self.assertEqual(42.0, res) res = sess.run(a.op) # An op, not a tensor. - self.assertEqual(None, res) + self.assertIsNone(res) tensor_runner = sess.make_callable(a) res = tensor_runner() self.assertEqual(42.0, res) op_runner = sess.make_callable(a.op) res = op_runner() - self.assertEqual(None, res) + self.assertIsNone(res) def testFetchSingletonByName(self): with session.Session() as sess: @@ -236,7 +243,7 @@ class SessionTest(test_util.TensorFlowTestCase): res = sess.run(a.name) self.assertEqual(42.0, res) res = sess.run(a.op) # An op, not a tensor. - self.assertEqual(None, res) + self.assertIsNone(res) def testFetchList(self): with session.Session() as sess: @@ -246,11 +253,11 @@ class SessionTest(test_util.TensorFlowTestCase): v = variables.Variable([54.0]) assign = v.assign([63.0]) res = sess.run([a, b, c, a.name, assign.op]) - self.assertTrue(isinstance(res, list)) + self.assertIsInstance(res, list) self.assertEqual([42.0, None, 44.0, 42.0, None], res) list_runner = sess.make_callable([a, b, c, a.name, assign.op]) res = list_runner() - self.assertTrue(isinstance(res, list)) + self.assertIsInstance(res, list) self.assertEqual([42.0, None, 44.0, 42.0, None], res) def testFetchTuple(self): @@ -259,11 +266,11 @@ class SessionTest(test_util.TensorFlowTestCase): b = control_flow_ops.no_op() # An op, not a tensor. c = constant_op.constant(44.0) res = sess.run((a, b, c, a.name)) - self.assertTrue(isinstance(res, tuple)) + self.assertIsInstance(res, tuple) self.assertEqual((42.0, None, 44.0, 42.0), res) tuple_runner = sess.make_callable((a, b, c, a.name)) res = tuple_runner() - self.assertTrue(isinstance(res, tuple)) + self.assertIsInstance(res, tuple) self.assertEqual((42.0, None, 44.0, 42.0), res) def testFetchNamedTuple(self): @@ -275,15 +282,15 @@ class SessionTest(test_util.TensorFlowTestCase): b = control_flow_ops.no_op() # An op, not a tensor. c = constant_op.constant(44.0) res = sess.run(ABC(a, b, c)) - self.assertTrue(isinstance(res, ABC)) + self.assertIsInstance(res, ABC) self.assertEqual(42.0, res.a) - self.assertEqual(None, res.b) + self.assertIsNone(res.b) self.assertEqual(44.0, res.c) namedtuple_runner = sess.make_callable(ABC(a, b, c)) res = namedtuple_runner() - self.assertTrue(isinstance(res, ABC)) + self.assertIsInstance(res, ABC) self.assertEqual(42.0, res.a) - self.assertEqual(None, res.b) + self.assertIsNone(res.b) self.assertEqual(44.0, res.c) def testFetchDict(self): @@ -292,9 +299,9 @@ class SessionTest(test_util.TensorFlowTestCase): b = control_flow_ops.no_op() # An op, not a tensor. c = constant_op.constant(44.0) res = sess.run({'a': a, 'b': b, 'c': c}) - self.assertTrue(isinstance(res, dict)) + self.assertIsInstance(res, dict) self.assertEqual(42.0, res['a']) - self.assertEqual(None, res['b']) + self.assertIsNone(res['b']) self.assertEqual(44.0, res['c']) def testFetchOrderedDict(self): @@ -303,10 +310,10 @@ class SessionTest(test_util.TensorFlowTestCase): b = control_flow_ops.no_op() # An op, not a tensor. c = constant_op.constant(44.0) res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)])) - self.assertTrue(isinstance(res, collections.OrderedDict)) + self.assertIsInstance(res, collections.OrderedDict) self.assertEqual([3, 2, 1], list(res.keys())) self.assertEqual(42.0, res[3]) - self.assertEqual(None, res[2]) + self.assertIsNone(res[2]) self.assertEqual(44.0, res[1]) @test_util.run_v1_only('b/120545219') @@ -393,23 +400,23 @@ class SessionTest(test_util.TensorFlowTestCase): a = constant_op.constant(a_val) res = sess.run([[], tuple(), {}]) - self.assertTrue(isinstance(res, list)) + self.assertIsInstance(res, list) self.assertEqual(3, len(res)) - self.assertTrue(isinstance(res[0], list)) + self.assertIsInstance(res[0], list) self.assertEqual(0, len(res[0])) - self.assertTrue(isinstance(res[1], tuple)) + self.assertIsInstance(res[1], tuple) self.assertEqual(0, len(res[1])) - self.assertTrue(isinstance(res[2], dict)) + self.assertIsInstance(res[2], dict) self.assertEqual(0, len(res[2])) res = sess.run([[], tuple(), {}, a]) - self.assertTrue(isinstance(res, list)) + self.assertIsInstance(res, list) self.assertEqual(4, len(res)) - self.assertTrue(isinstance(res[0], list)) + self.assertIsInstance(res[0], list) self.assertEqual(0, len(res[0])) - self.assertTrue(isinstance(res[1], tuple)) + self.assertIsInstance(res[1], tuple) self.assertEqual(0, len(res[1])) - self.assertTrue(isinstance(res[2], dict)) + self.assertIsInstance(res[2], dict) self.assertEqual(0, len(res[2])) self.assertEqual(a_val, res[3]) @@ -417,7 +424,7 @@ class SessionTest(test_util.TensorFlowTestCase): with session.Session() as sess: # pylint: disable=invalid-name ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) - DEFG = collections.namedtuple('DEFG', ['d', 'e', 'f', 'g']) + DEFGHI = collections.namedtuple('DEFGHI', ['d', 'e', 'f', 'g', 'h', 'i']) # pylint: enable=invalid-name a_val = 42.0 b_val = None @@ -425,124 +432,141 @@ class SessionTest(test_util.TensorFlowTestCase): a = constant_op.constant(a_val) b = control_flow_ops.no_op() # An op, not a tensor. c = constant_op.constant(c_val) - # List of lists, tuples, namedtuple, and dict - res = sess.run([[a, b, c], (a, b, c), - ABC(a=a, b=b, c=c), { - 'a': a.name, - 'c': c, - 'b': b - }]) - self.assertTrue(isinstance(res, list)) - self.assertEqual(4, len(res)) - self.assertTrue(isinstance(res[0], list)) + test_dct = {'a': a.name, 'c': c, 'b': b} + test_dct_types = [dict, frozendict, defaultdict] + # List of lists, tuples, namedtuple, dict, frozendict, and defaultdict + res = sess.run([ + [a, b, c], + (a, b, c), + ABC(a=a, b=b, c=c), + dict(test_dct), + frozendict(test_dct), + defaultdict(str, test_dct), + ]) + self.assertIsInstance(res, list) + self.assertEqual(6, len(res)) + self.assertIsInstance(res[0], list) self.assertEqual(3, len(res[0])) self.assertEqual(a_val, res[0][0]) self.assertEqual(b_val, res[0][1]) self.assertEqual(c_val, res[0][2]) - self.assertTrue(isinstance(res[1], tuple)) + self.assertIsInstance(res[1], tuple) self.assertEqual(3, len(res[1])) self.assertEqual(a_val, res[1][0]) self.assertEqual(b_val, res[1][1]) self.assertEqual(c_val, res[1][2]) - self.assertTrue(isinstance(res[2], ABC)) + self.assertIsInstance(res[2], ABC) self.assertEqual(a_val, res[2].a) self.assertEqual(b_val, res[2].b) self.assertEqual(c_val, res[2].c) - self.assertTrue(isinstance(res[3], dict)) - self.assertEqual(3, len(res[3])) - self.assertEqual(a_val, res[3]['a']) - self.assertEqual(b_val, res[3]['b']) - self.assertEqual(c_val, res[3]['c']) - # Tuple of lists, tuples, namedtuple, and dict - res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), { - 'a': a, - 'c': c, - 'b': b - })) - self.assertTrue(isinstance(res, tuple)) - self.assertEqual(4, len(res)) - self.assertTrue(isinstance(res[0], list)) + for expected_type, r in zip(test_dct_types, res[3:]): + self.assertIsInstance(r, expected_type) + self.assertEqual(3, len(r)) + self.assertEqual(a_val, r['a']) + self.assertEqual(b_val, r['b']) + self.assertEqual(c_val, r['c']) + self.assertEqual(res[5].default_factory, str) + # Tuple of lists, tuples, namedtuple, dict, frozendict, and defaultdict + res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, + c=c), dict(test_dct), + frozendict(test_dct), defaultdict(str, test_dct))) + self.assertIsInstance(res, tuple) + self.assertEqual(6, len(res)) + self.assertIsInstance(res[0], list) self.assertEqual(3, len(res[0])) self.assertEqual(a_val, res[0][0]) self.assertEqual(b_val, res[0][1]) self.assertEqual(c_val, res[0][2]) - self.assertTrue(isinstance(res[1], tuple)) + self.assertIsInstance(res[1], tuple) self.assertEqual(3, len(res[1])) self.assertEqual(a_val, res[1][0]) self.assertEqual(b_val, res[1][1]) self.assertEqual(c_val, res[1][2]) - self.assertTrue(isinstance(res[2], ABC)) + self.assertIsInstance(res[2], ABC) self.assertEqual(a_val, res[2].a) self.assertEqual(b_val, res[2].b) self.assertEqual(c_val, res[2].c) - self.assertTrue(isinstance(res[3], dict)) - self.assertEqual(3, len(res[3])) - self.assertEqual(a_val, res[3]['a']) - self.assertEqual(b_val, res[3]['b']) - self.assertEqual(c_val, res[3]['c']) - # Namedtuple of lists, tuples, namedtuples, and dict + for expected_type, r in zip(test_dct_types, res[3:]): + self.assertIsInstance(r, expected_type) + self.assertEqual(3, len(r)) + self.assertEqual(a_val, r['a']) + self.assertEqual(b_val, r['b']) + self.assertEqual(c_val, r['c']) + self.assertEqual(res[5].default_factory, str) + + # Namedtuple of lists, tuples, namedtuples, dict, frozendict, defaultdict res = sess.run( - DEFG( + DEFGHI( d=[a, b, c], e=(a, b, c), f=ABC(a=a.name, b=b, c=c), - g={ - 'a': a, - 'c': c, - 'b': b - })) - self.assertTrue(isinstance(res, DEFG)) - self.assertTrue(isinstance(res.d, list)) + g=dict(test_dct), + h=frozendict(test_dct), + i=defaultdict(str, test_dct))) + self.assertIsInstance(res, DEFGHI) + self.assertIsInstance(res.d, list) self.assertEqual(3, len(res.d)) self.assertEqual(a_val, res.d[0]) self.assertEqual(b_val, res.d[1]) self.assertEqual(c_val, res.d[2]) - self.assertTrue(isinstance(res.e, tuple)) + self.assertIsInstance(res.e, tuple) self.assertEqual(3, len(res.e)) self.assertEqual(a_val, res.e[0]) self.assertEqual(b_val, res.e[1]) self.assertEqual(c_val, res.e[2]) - self.assertTrue(isinstance(res.f, ABC)) + self.assertIsInstance(res.f, ABC) self.assertEqual(a_val, res.f.a) self.assertEqual(b_val, res.f.b) self.assertEqual(c_val, res.f.c) - self.assertTrue(isinstance(res.g, dict)) + self.assertIsInstance(res.g, dict) self.assertEqual(3, len(res.g)) self.assertEqual(a_val, res.g['a']) self.assertEqual(b_val, res.g['b']) self.assertEqual(c_val, res.g['c']) - # Dict of lists, tuples, namedtuples, and dict + self.assertIsInstance(res.h, frozendict) + self.assertEqual(3, len(res.h)) + self.assertEqual(a_val, res.h['a']) + self.assertEqual(b_val, res.h['b']) + self.assertEqual(c_val, res.h['c']) + self.assertIsInstance(res.i, defaultdict) + self.assertEqual(3, len(res.i)) + self.assertEqual(a_val, res.i['a']) + self.assertEqual(b_val, res.i['b']) + self.assertEqual(c_val, res.i['c']) + self.assertEqual(res.i.default_factory, str) + # Dict of lists, tuples, namedtuples, dict, frozendict, defaultdict res = sess.run({ 'd': [a, b, c], 'e': (a, b, c), 'f': ABC(a=a, b=b, c=c), - 'g': { - 'a': a.name, - 'c': c, - 'b': b - } + 'g': dict(test_dct), + 'h': frozendict(test_dct), + 'i': defaultdict(str, test_dct), }) - self.assertTrue(isinstance(res, dict)) - self.assertEqual(4, len(res)) - self.assertTrue(isinstance(res['d'], list)) + self.assertIsInstance(res, dict) + self.assertEqual(6, len(res)) + self.assertIsInstance(res['d'], list) self.assertEqual(3, len(res['d'])) self.assertEqual(a_val, res['d'][0]) self.assertEqual(b_val, res['d'][1]) self.assertEqual(c_val, res['d'][2]) - self.assertTrue(isinstance(res['e'], tuple)) + self.assertIsInstance(res['e'], tuple) self.assertEqual(3, len(res['e'])) self.assertEqual(a_val, res['e'][0]) self.assertEqual(b_val, res['e'][1]) self.assertEqual(c_val, res['e'][2]) - self.assertTrue(isinstance(res['f'], ABC)) + self.assertIsInstance(res['f'], ABC) self.assertEqual(a_val, res['f'].a) self.assertEqual(b_val, res['f'].b) self.assertEqual(c_val, res['f'].c) - self.assertTrue(isinstance(res['g'], dict)) - self.assertEqual(3, len(res['g'])) - self.assertEqual(a_val, res['g']['a']) - self.assertEqual(b_val, res['g']['b']) - self.assertEqual(c_val, res['g']['c']) + for expected_type, r_key in zip(test_dct_types, ('g', 'h', 'i')): + r = res[r_key] + self.assertIsInstance(r, expected_type) + self.assertEqual(3, len(r)) + self.assertEqual(a_val, r['a']) + self.assertEqual(b_val, r['b']) + self.assertEqual(c_val, r['c']) + self.assertEqual(res['i'].default_factory, str) def testFetchTensorObject(self): with session.Session() as s: @@ -1279,7 +1303,7 @@ class SessionTest(test_util.TensorFlowTestCase): @test_util.run_v1_only('b/120545219') def testNotEntered(self): # pylint: disable=protected-access - self.assertEqual(ops._default_session_stack.get_default(), None) + self.assertIsNone(ops._default_session_stack.get_default()) # pylint: enable=protected-access with ops.device('/cpu:0'): sess = session.Session() @@ -1326,10 +1350,10 @@ class SessionTest(test_util.TensorFlowTestCase): with warnings.catch_warnings(record=True) as w: sess2 = session.InteractiveSession() self.assertEqual(1, len(w)) - self.assertTrue('An interactive session is already active. This can cause ' - 'out-of-memory errors in some cases. You must explicitly ' - 'call `InteractiveSession.close()` to release resources ' - 'held by the other session(s).' in str(w[0].message)) + self.assertIn('An interactive session is already active. This can cause ' + 'out-of-memory errors in some cases. You must explicitly ' + 'call `InteractiveSession.close()` to release resources ' + 'held by the other session(s).', str(w[0].message)) sess2.close() sess.close() @@ -1610,10 +1634,10 @@ class SessionTest(test_util.TensorFlowTestCase): e = constant_op.constant(44.0, name=b'e') f = constant_op.constant(45.0, name=r'f') - self.assertTrue(isinstance(c.name, six.text_type)) - self.assertTrue(isinstance(d.name, six.text_type)) - self.assertTrue(isinstance(e.name, six.text_type)) - self.assertTrue(isinstance(f.name, six.text_type)) + self.assertIsInstance(c.name, six.text_type) + self.assertIsInstance(d.name, six.text_type) + self.assertIsInstance(e.name, six.text_type) + self.assertIsInstance(f.name, six.text_type) self.assertEqual(42.0, sess.run('c:0')) self.assertEqual(42.0, sess.run(u'c:0')) @@ -1673,10 +1697,10 @@ class SessionTest(test_util.TensorFlowTestCase): with ops.device('/cpu:0'): with session.Session() as sess: sess.run(constant_op.constant(1.0)) - self.assertTrue(not run_metadata.HasField('step_stats')) + self.assertFalse(run_metadata.HasField('step_stats')) sess.run(constant_op.constant(1.0), run_metadata=run_metadata) - self.assertTrue(not run_metadata.HasField('step_stats')) + self.assertFalse(run_metadata.HasField('step_stats')) sess.run( constant_op.constant(1.0), @@ -1697,11 +1721,11 @@ class SessionTest(test_util.TensorFlowTestCase): sess.run(constant_op.constant(1.0), options=None, run_metadata=None) sess.run( constant_op.constant(1.0), options=None, run_metadata=run_metadata) - self.assertTrue(not run_metadata.HasField('step_stats')) + self.assertFalse(run_metadata.HasField('step_stats')) sess.run( constant_op.constant(1.0), options=run_options, run_metadata=None) - self.assertTrue(not run_metadata.HasField('step_stats')) + self.assertFalse(run_metadata.HasField('step_stats')) sess.run( constant_op.constant(1.0), @@ -1730,9 +1754,9 @@ class SessionTest(test_util.TensorFlowTestCase): with ops.Graph().as_default(), ops.device('/cpu:0'): a = constant_op.constant([[1, 2]]) sess = session.Session() - self.assertFalse('_output_shapes' in sess.graph_def.node[0].attr) + self.assertNotIn('_output_shapes', sess.graph_def.node[0].attr) # Avoid lint error regarding 'unused' var a. - self.assertTrue(a == a) + self.assertEqual(a, a) def testInferShapesTrue(self): config_pb = config_pb2.ConfigProto( @@ -1740,9 +1764,9 @@ class SessionTest(test_util.TensorFlowTestCase): with ops.Graph().as_default(), ops.device('/cpu:0'): a = constant_op.constant([[1, 2]]) sess = session.Session(config=config_pb) - self.assertTrue('_output_shapes' in sess.graph_def.node[0].attr) + self.assertIn('_output_shapes', sess.graph_def.node[0].attr) # Avoid lint error regarding 'unused' var a. - self.assertTrue(a == a) + self.assertEqual(a, a) def testBuildCostModel(self): run_options = config_pb2.RunOptions()