Use "is" instead of ==, and Reference for dict keys {} in Bijector.

PiperOrigin-RevId: 263629253
This commit is contained in:
Martin Wicke 2019-08-15 13:26:49 -07:00 committed by TensorFlower Gardener
parent 053f39e766
commit f7f6f8655a
3 changed files with 15 additions and 14 deletions

View File

@ -116,6 +116,7 @@ class BrokenBijector(bijector.Bijector):
raise IntentionallyMissingError raise IntentionallyMissingError
return math_ops.log(2.) return math_ops.log(2.)
class BijectorTestEventNdims(test.TestCase): class BijectorTestEventNdims(test.TestCase):
def testBijectorNonIntegerEventNdims(self): def testBijectorNonIntegerEventNdims(self):
@ -162,12 +163,8 @@ class BijectorCachingTestBase(object):
_ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0) _ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
# Now, everything should be cached if the argument is y. # Now, everything should be cached if the argument is y.
broken_bijector.inverse(y)
broken_bijector.inverse_log_det_jacobian(y, event_ndims=0) broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
try:
broken_bijector.inverse(y)
broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
except IntentionallyMissingError:
raise AssertionError("Tests failed! Cached values not used.")
# Different event_ndims should not be cached. # Different event_ndims should not be cached.
with self.assertRaises(IntentionallyMissingError): with self.assertRaises(IntentionallyMissingError):
@ -182,11 +179,8 @@ class BijectorCachingTestBase(object):
_ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0) _ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
# Now, everything should be cached if the argument is x. # Now, everything should be cached if the argument is x.
try: broken_bijector.forward(x)
broken_bijector.forward(x) broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
except IntentionallyMissingError:
raise AssertionError("Tests failed! Cached values not used.")
# Different event_ndims should not be cached. # Different event_ndims should not be cached.
with self.assertRaises(IntentionallyMissingError): with self.assertRaises(IntentionallyMissingError):

View File

@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import object_identity
__all__ = [ __all__ = [
@ -64,12 +65,14 @@ class _Mapping(collections.namedtuple(
@property @property
def x_key(self): def x_key(self):
"""Returns key used for caching Y=g(X).""" """Returns key used for caching Y=g(X)."""
return (self.x,) + self._deep_tuple(tuple(sorted(self.kwargs.items()))) return ((object_identity.Reference(self.x),) +
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
@property @property
def y_key(self): def y_key(self):
"""Returns key used for caching X=g^{-1}(Y).""" """Returns key used for caching X=g^{-1}(Y)."""
return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items()))) return ((object_identity.Reference(self.y),) +
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None): def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
"""Returns new _Mapping with args merged with self. """Returns new _Mapping with args merged with self.
@ -108,7 +111,7 @@ class _Mapping(collections.namedtuple(
new = {} if new is None else new new = {} if new is None else new
for k, v in six.iteritems(new): for k, v in six.iteritems(new):
val = old.get(k, None) val = old.get(k, None)
if val is not None and val != v: if val is not None and val is not v:
raise ValueError("Found different value for existing key " raise ValueError("Found different value for existing key "
"(key:{} old_value:{} new_value:{}".format( "(key:{} old_value:{} new_value:{}".format(
k, old[k], v)) k, old[k], v))
@ -119,7 +122,7 @@ class _Mapping(collections.namedtuple(
"""Helper to merge which handles merging one value.""" """Helper to merge which handles merging one value."""
if old is None: if old is None:
return new return new
elif new is not None and old != new: elif new is not None and old is not new:
raise ValueError("Incompatible values: %s != %s" % (old, new)) raise ValueError("Incompatible values: %s != %s" % (old, new))
return old return old
@ -567,6 +570,7 @@ class Bijector(object):
self._constant_ildj_map = {} self._constant_ildj_map = {}
self._validate_args = validate_args self._validate_args = validate_args
self._dtype = dtype self._dtype = dtype
# These dicts can only be accessed using _Mapping.x_key or _Mapping.y_key
self._from_y = {} self._from_y = {}
self._from_x = {} self._from_x = {}
if name: if name:

View File

@ -129,6 +129,9 @@ class ObjectIdentityDictionary(collections_abc.MutableMapping):
for key in self._storage: for key in self._storage:
yield key.unwrapped yield key.unwrapped
def __repr__(self):
return "ObjectIdentityDictionary(%s)" % repr(self._storage)
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary): class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
"""Like weakref.WeakKeyDictionary, but compares objects with "is".""" """Like weakref.WeakKeyDictionary, but compares objects with "is"."""