Add entropy method to Deterministic.

PiperOrigin-RevId: 208657597
This commit is contained in:
Dustin Tran 2018-08-14 09:17:36 -07:00 committed by TensorFlower Gardener
parent 95839c96ae
commit d7f93284c8
2 changed files with 17 additions and 0 deletions

View File

@ -173,6 +173,13 @@ class DeterministicTest(test.TestCase):
self.assertAllClose(
np.zeros(sample_shape_ + (2,)).astype(np.float32), sample_)
def testEntropy(self):
loc = np.array([-0.1, -3.2, 7.])
deterministic = deterministic_lib.Deterministic(loc=loc)
with self.test_session() as sess:
entropy_ = sess.run(deterministic.entropy())
self.assertAllEqual(np.zeros(3), entropy_)
class VectorDeterministicTest(test.TestCase):
@ -290,6 +297,13 @@ class VectorDeterministicTest(test.TestCase):
self.assertAllClose(
np.zeros(sample_shape_ + (2, 1)).astype(np.float32), sample_)
def testEntropy(self):
loc = np.array([[8.3, 1.2, 3.3], [-0.1, -3.2, 7.]])
deterministic = deterministic_lib.VectorDeterministic(loc=loc)
with self.test_session() as sess:
entropy_ = sess.run(deterministic.entropy())
self.assertAllEqual(np.zeros(2), entropy_)
if __name__ == "__main__":
test.main()

View File

@ -152,6 +152,9 @@ class _BaseDeterministic(distribution.Distribution):
"""Relative tolerance for comparing points to `self.loc`."""
return self._rtol
def _entropy(self):
return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype)
def _mean(self):
return array_ops.identity(self.loc)