Use "is" instead of ==, and Reference for dict keys {} in Bijector.

PiperOrigin-RevId: 263629253
This commit is contained in:
Martin Wicke 2019-08-15 13:26:49 -07:00 committed by TensorFlower Gardener
parent 053f39e766
commit f7f6f8655a
3 changed files with 15 additions and 14 deletions
tensorflow/python
kernel_tests/distributions
ops/distributions
util

View File

@ -116,6 +116,7 @@ class BrokenBijector(bijector.Bijector):
raise IntentionallyMissingError
return math_ops.log(2.)
class BijectorTestEventNdims(test.TestCase):
def testBijectorNonIntegerEventNdims(self):
@ -162,12 +163,8 @@ class BijectorCachingTestBase(object):
_ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
# Now, everything should be cached if the argument is y.
broken_bijector.inverse(y)
broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
try:
broken_bijector.inverse(y)
broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
except IntentionallyMissingError:
raise AssertionError("Tests failed! Cached values not used.")
# Different event_ndims should not be cached.
with self.assertRaises(IntentionallyMissingError):
@ -182,11 +179,8 @@ class BijectorCachingTestBase(object):
_ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
# Now, everything should be cached if the argument is x.
try:
broken_bijector.forward(x)
broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
except IntentionallyMissingError:
raise AssertionError("Tests failed! Cached values not used.")
broken_bijector.forward(x)
broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
# Different event_ndims should not be cached.
with self.assertRaises(IntentionallyMissingError):

View File

@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import object_identity
__all__ = [
@ -64,12 +65,14 @@ class _Mapping(collections.namedtuple(
@property
def x_key(self):
"""Returns key used for caching Y=g(X)."""
return (self.x,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))
return ((object_identity.Reference(self.x),) +
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
@property
def y_key(self):
"""Returns key used for caching X=g^{-1}(Y)."""
return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))
return ((object_identity.Reference(self.y),) +
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
"""Returns new _Mapping with args merged with self.
@ -108,7 +111,7 @@ class _Mapping(collections.namedtuple(
new = {} if new is None else new
for k, v in six.iteritems(new):
val = old.get(k, None)
if val is not None and val != v:
if val is not None and val is not v:
raise ValueError("Found different value for existing key "
"(key:{} old_value:{} new_value:{}".format(
k, old[k], v))
@ -119,7 +122,7 @@ class _Mapping(collections.namedtuple(
"""Helper to merge which handles merging one value."""
if old is None:
return new
elif new is not None and old != new:
elif new is not None and old is not new:
raise ValueError("Incompatible values: %s != %s" % (old, new))
return old
@ -567,6 +570,7 @@ class Bijector(object):
self._constant_ildj_map = {}
self._validate_args = validate_args
self._dtype = dtype
# These dicts can only be accessed using _Mapping.x_key or _Mapping.y_key
self._from_y = {}
self._from_x = {}
if name:

View File

@ -129,6 +129,9 @@ class ObjectIdentityDictionary(collections_abc.MutableMapping):
for key in self._storage:
yield key.unwrapped
def __repr__(self):
return "ObjectIdentityDictionary(%s)" % repr(self._storage)
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
"""Like weakref.WeakKeyDictionary, but compares objects with "is"."""