Use "is" instead of ==, and Reference for dict keys {} in Bijector.
PiperOrigin-RevId: 263629253
This commit is contained in:
parent
053f39e766
commit
f7f6f8655a
tensorflow/python
@ -116,6 +116,7 @@ class BrokenBijector(bijector.Bijector):
|
||||
raise IntentionallyMissingError
|
||||
return math_ops.log(2.)
|
||||
|
||||
|
||||
class BijectorTestEventNdims(test.TestCase):
|
||||
|
||||
def testBijectorNonIntegerEventNdims(self):
|
||||
@ -162,12 +163,8 @@ class BijectorCachingTestBase(object):
|
||||
_ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
|
||||
|
||||
# Now, everything should be cached if the argument is y.
|
||||
broken_bijector.inverse(y)
|
||||
broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
|
||||
try:
|
||||
broken_bijector.inverse(y)
|
||||
broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
|
||||
except IntentionallyMissingError:
|
||||
raise AssertionError("Tests failed! Cached values not used.")
|
||||
|
||||
# Different event_ndims should not be cached.
|
||||
with self.assertRaises(IntentionallyMissingError):
|
||||
@ -182,11 +179,8 @@ class BijectorCachingTestBase(object):
|
||||
_ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
|
||||
|
||||
# Now, everything should be cached if the argument is x.
|
||||
try:
|
||||
broken_bijector.forward(x)
|
||||
broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
|
||||
except IntentionallyMissingError:
|
||||
raise AssertionError("Tests failed! Cached values not used.")
|
||||
broken_bijector.forward(x)
|
||||
broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
|
||||
|
||||
# Different event_ndims should not be cached.
|
||||
with self.assertRaises(IntentionallyMissingError):
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.distributions import util as distribution_util
|
||||
from tensorflow.python.util import object_identity
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -64,12 +65,14 @@ class _Mapping(collections.namedtuple(
|
||||
@property
|
||||
def x_key(self):
|
||||
"""Returns key used for caching Y=g(X)."""
|
||||
return (self.x,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))
|
||||
return ((object_identity.Reference(self.x),) +
|
||||
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
|
||||
|
||||
@property
|
||||
def y_key(self):
|
||||
"""Returns key used for caching X=g^{-1}(Y)."""
|
||||
return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))
|
||||
return ((object_identity.Reference(self.y),) +
|
||||
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
|
||||
|
||||
def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
|
||||
"""Returns new _Mapping with args merged with self.
|
||||
@ -108,7 +111,7 @@ class _Mapping(collections.namedtuple(
|
||||
new = {} if new is None else new
|
||||
for k, v in six.iteritems(new):
|
||||
val = old.get(k, None)
|
||||
if val is not None and val != v:
|
||||
if val is not None and val is not v:
|
||||
raise ValueError("Found different value for existing key "
|
||||
"(key:{} old_value:{} new_value:{}".format(
|
||||
k, old[k], v))
|
||||
@ -119,7 +122,7 @@ class _Mapping(collections.namedtuple(
|
||||
"""Helper to merge which handles merging one value."""
|
||||
if old is None:
|
||||
return new
|
||||
elif new is not None and old != new:
|
||||
elif new is not None and old is not new:
|
||||
raise ValueError("Incompatible values: %s != %s" % (old, new))
|
||||
return old
|
||||
|
||||
@ -567,6 +570,7 @@ class Bijector(object):
|
||||
self._constant_ildj_map = {}
|
||||
self._validate_args = validate_args
|
||||
self._dtype = dtype
|
||||
# These dicts can only be accessed using _Mapping.x_key or _Mapping.y_key
|
||||
self._from_y = {}
|
||||
self._from_x = {}
|
||||
if name:
|
||||
|
@ -129,6 +129,9 @@ class ObjectIdentityDictionary(collections_abc.MutableMapping):
|
||||
for key in self._storage:
|
||||
yield key.unwrapped
|
||||
|
||||
def __repr__(self):
|
||||
return "ObjectIdentityDictionary(%s)" % repr(self._storage)
|
||||
|
||||
|
||||
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
|
||||
"""Like weakref.WeakKeyDictionary, but compares objects with "is"."""
|
||||
|
Loading…
Reference in New Issue
Block a user