Use "is" instead of ==, and Reference for dict keys {} in Bijector.
PiperOrigin-RevId: 263629253
This commit is contained in:
parent
053f39e766
commit
f7f6f8655a
@ -116,6 +116,7 @@ class BrokenBijector(bijector.Bijector):
|
|||||||
raise IntentionallyMissingError
|
raise IntentionallyMissingError
|
||||||
return math_ops.log(2.)
|
return math_ops.log(2.)
|
||||||
|
|
||||||
|
|
||||||
class BijectorTestEventNdims(test.TestCase):
|
class BijectorTestEventNdims(test.TestCase):
|
||||||
|
|
||||||
def testBijectorNonIntegerEventNdims(self):
|
def testBijectorNonIntegerEventNdims(self):
|
||||||
@ -162,12 +163,8 @@ class BijectorCachingTestBase(object):
|
|||||||
_ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
|
_ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
|
||||||
|
|
||||||
# Now, everything should be cached if the argument is y.
|
# Now, everything should be cached if the argument is y.
|
||||||
|
broken_bijector.inverse(y)
|
||||||
broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
|
broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
|
||||||
try:
|
|
||||||
broken_bijector.inverse(y)
|
|
||||||
broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
|
|
||||||
except IntentionallyMissingError:
|
|
||||||
raise AssertionError("Tests failed! Cached values not used.")
|
|
||||||
|
|
||||||
# Different event_ndims should not be cached.
|
# Different event_ndims should not be cached.
|
||||||
with self.assertRaises(IntentionallyMissingError):
|
with self.assertRaises(IntentionallyMissingError):
|
||||||
@ -182,11 +179,8 @@ class BijectorCachingTestBase(object):
|
|||||||
_ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
|
_ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
|
||||||
|
|
||||||
# Now, everything should be cached if the argument is x.
|
# Now, everything should be cached if the argument is x.
|
||||||
try:
|
broken_bijector.forward(x)
|
||||||
broken_bijector.forward(x)
|
broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
|
||||||
broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
|
|
||||||
except IntentionallyMissingError:
|
|
||||||
raise AssertionError("Tests failed! Cached values not used.")
|
|
||||||
|
|
||||||
# Different event_ndims should not be cached.
|
# Different event_ndims should not be cached.
|
||||||
with self.assertRaises(IntentionallyMissingError):
|
with self.assertRaises(IntentionallyMissingError):
|
||||||
|
@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.distributions import util as distribution_util
|
from tensorflow.python.ops.distributions import util as distribution_util
|
||||||
|
from tensorflow.python.util import object_identity
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -64,12 +65,14 @@ class _Mapping(collections.namedtuple(
|
|||||||
@property
|
@property
|
||||||
def x_key(self):
|
def x_key(self):
|
||||||
"""Returns key used for caching Y=g(X)."""
|
"""Returns key used for caching Y=g(X)."""
|
||||||
return (self.x,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))
|
return ((object_identity.Reference(self.x),) +
|
||||||
|
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def y_key(self):
|
def y_key(self):
|
||||||
"""Returns key used for caching X=g^{-1}(Y)."""
|
"""Returns key used for caching X=g^{-1}(Y)."""
|
||||||
return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))
|
return ((object_identity.Reference(self.y),) +
|
||||||
|
self._deep_tuple(tuple(sorted(self.kwargs.items()))))
|
||||||
|
|
||||||
def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
|
def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
|
||||||
"""Returns new _Mapping with args merged with self.
|
"""Returns new _Mapping with args merged with self.
|
||||||
@ -108,7 +111,7 @@ class _Mapping(collections.namedtuple(
|
|||||||
new = {} if new is None else new
|
new = {} if new is None else new
|
||||||
for k, v in six.iteritems(new):
|
for k, v in six.iteritems(new):
|
||||||
val = old.get(k, None)
|
val = old.get(k, None)
|
||||||
if val is not None and val != v:
|
if val is not None and val is not v:
|
||||||
raise ValueError("Found different value for existing key "
|
raise ValueError("Found different value for existing key "
|
||||||
"(key:{} old_value:{} new_value:{}".format(
|
"(key:{} old_value:{} new_value:{}".format(
|
||||||
k, old[k], v))
|
k, old[k], v))
|
||||||
@ -119,7 +122,7 @@ class _Mapping(collections.namedtuple(
|
|||||||
"""Helper to merge which handles merging one value."""
|
"""Helper to merge which handles merging one value."""
|
||||||
if old is None:
|
if old is None:
|
||||||
return new
|
return new
|
||||||
elif new is not None and old != new:
|
elif new is not None and old is not new:
|
||||||
raise ValueError("Incompatible values: %s != %s" % (old, new))
|
raise ValueError("Incompatible values: %s != %s" % (old, new))
|
||||||
return old
|
return old
|
||||||
|
|
||||||
@ -567,6 +570,7 @@ class Bijector(object):
|
|||||||
self._constant_ildj_map = {}
|
self._constant_ildj_map = {}
|
||||||
self._validate_args = validate_args
|
self._validate_args = validate_args
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
|
# These dicts can only be accessed using _Mapping.x_key or _Mapping.y_key
|
||||||
self._from_y = {}
|
self._from_y = {}
|
||||||
self._from_x = {}
|
self._from_x = {}
|
||||||
if name:
|
if name:
|
||||||
|
@ -129,6 +129,9 @@ class ObjectIdentityDictionary(collections_abc.MutableMapping):
|
|||||||
for key in self._storage:
|
for key in self._storage:
|
||||||
yield key.unwrapped
|
yield key.unwrapped
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "ObjectIdentityDictionary(%s)" % repr(self._storage)
|
||||||
|
|
||||||
|
|
||||||
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
|
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
|
||||||
"""Like weakref.WeakKeyDictionary, but compares objects with "is"."""
|
"""Like weakref.WeakKeyDictionary, but compares objects with "is"."""
|
||||||
|
Loading…
Reference in New Issue
Block a user