Add entropy method to Deterministic.
PiperOrigin-RevId: 208657597
This commit is contained in:
parent
95839c96ae
commit
d7f93284c8
@ -173,6 +173,13 @@ class DeterministicTest(test.TestCase):
|
||||
self.assertAllClose(
|
||||
np.zeros(sample_shape_ + (2,)).astype(np.float32), sample_)
|
||||
|
||||
def testEntropy(self):
|
||||
loc = np.array([-0.1, -3.2, 7.])
|
||||
deterministic = deterministic_lib.Deterministic(loc=loc)
|
||||
with self.test_session() as sess:
|
||||
entropy_ = sess.run(deterministic.entropy())
|
||||
self.assertAllEqual(np.zeros(3), entropy_)
|
||||
|
||||
|
||||
class VectorDeterministicTest(test.TestCase):
|
||||
|
||||
@ -290,6 +297,13 @@ class VectorDeterministicTest(test.TestCase):
|
||||
self.assertAllClose(
|
||||
np.zeros(sample_shape_ + (2, 1)).astype(np.float32), sample_)
|
||||
|
||||
def testEntropy(self):
|
||||
loc = np.array([[8.3, 1.2, 3.3], [-0.1, -3.2, 7.]])
|
||||
deterministic = deterministic_lib.VectorDeterministic(loc=loc)
|
||||
with self.test_session() as sess:
|
||||
entropy_ = sess.run(deterministic.entropy())
|
||||
self.assertAllEqual(np.zeros(2), entropy_)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -152,6 +152,9 @@ class _BaseDeterministic(distribution.Distribution):
|
||||
"""Relative tolerance for comparing points to `self.loc`."""
|
||||
return self._rtol
|
||||
|
||||
def _entropy(self):
|
||||
return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype)
|
||||
|
||||
def _mean(self):
|
||||
return array_ops.identity(self.loc)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user