Distributions should raise the original exception (log_prob not implemented) instead of the fallback exception (prob not implemented).

Additionally, in a nested structure of transformed distributions, it can be useful to know which distribution is raising this error.

PiperOrigin-RevId: 213618306
This commit is contained in:
Brian Patton 2018-09-19 06:37:43 -07:00 committed by TensorFlower Gardener
parent 22ff5db5d4
commit 7bc9f39687

View File

@ -601,7 +601,8 @@ class Distribution(_BaseDistribution):
return type(self)(**parameters)
def _batch_shape_tensor(self):
raise NotImplementedError("batch_shape_tensor is not implemented")
raise NotImplementedError(
"batch_shape_tensor is not implemented: {}".format(type(self).__name__))
def batch_shape_tensor(self, name="batch_shape_tensor"):
"""Shape of a single sample from a single event index as a 1-D `Tensor`.
@ -640,7 +641,8 @@ class Distribution(_BaseDistribution):
return tensor_shape.as_shape(self._batch_shape())
def _event_shape_tensor(self):
raise NotImplementedError("event_shape_tensor is not implemented")
raise NotImplementedError(
"event_shape_tensor is not implemented: {}".format(type(self).__name__))
def event_shape_tensor(self, name="event_shape_tensor"):
"""Shape of a single sample from a single batch as a 1-D int32 `Tensor`.
@ -701,7 +703,8 @@ class Distribution(_BaseDistribution):
name="is_scalar_batch")
def _sample_n(self, n, seed=None):
raise NotImplementedError("sample_n is not implemented")
raise NotImplementedError("sample_n is not implemented: {}".format(
type(self).__name__))
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
with self._name_scope(name, values=[sample_shape]):
@ -733,15 +736,19 @@ class Distribution(_BaseDistribution):
return self._call_sample_n(sample_shape, seed, name)
def _log_prob(self, value):
raise NotImplementedError("log_prob is not implemented")
raise NotImplementedError("log_prob is not implemented: {}".format(
type(self).__name__))
def _call_log_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_prob(value, **kwargs)
except NotImplementedError:
return math_ops.log(self._prob(value, **kwargs))
except NotImplementedError as original_exception:
try:
return math_ops.log(self._prob(value, **kwargs))
except NotImplementedError:
raise original_exception
def log_prob(self, value, name="log_prob"):
"""Log probability density/mass function.
@ -757,15 +764,19 @@ class Distribution(_BaseDistribution):
return self._call_log_prob(value, name)
def _prob(self, value):
raise NotImplementedError("prob is not implemented")
raise NotImplementedError("prob is not implemented: {}".format(
type(self).__name__))
def _call_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._prob(value, **kwargs)
except NotImplementedError:
return math_ops.exp(self._log_prob(value, **kwargs))
except NotImplementedError as original_exception:
try:
return math_ops.exp(self._log_prob(value, **kwargs))
except NotImplementedError:
raise original_exception
def prob(self, value, name="prob"):
"""Probability density/mass function.
@ -781,15 +792,19 @@ class Distribution(_BaseDistribution):
return self._call_prob(value, name)
def _log_cdf(self, value):
raise NotImplementedError("log_cdf is not implemented")
raise NotImplementedError("log_cdf is not implemented: {}".format(
type(self).__name__))
def _call_log_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_cdf(value, **kwargs)
except NotImplementedError:
return math_ops.log(self._cdf(value, **kwargs))
except NotImplementedError as original_exception:
try:
return math_ops.log(self._cdf(value, **kwargs))
except NotImplementedError:
raise original_exception
def log_cdf(self, value, name="log_cdf"):
"""Log cumulative distribution function.
@ -815,15 +830,19 @@ class Distribution(_BaseDistribution):
return self._call_log_cdf(value, name)
def _cdf(self, value):
raise NotImplementedError("cdf is not implemented")
raise NotImplementedError("cdf is not implemented: {}".format(
type(self).__name__))
def _call_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._cdf(value, **kwargs)
except NotImplementedError:
return math_ops.exp(self._log_cdf(value, **kwargs))
except NotImplementedError as original_exception:
try:
return math_ops.exp(self._log_cdf(value, **kwargs))
except NotImplementedError:
raise original_exception
def cdf(self, value, name="cdf"):
"""Cumulative distribution function.
@ -845,15 +864,20 @@ class Distribution(_BaseDistribution):
return self._call_cdf(value, name)
def _log_survival_function(self, value):
raise NotImplementedError("log_survival_function is not implemented")
raise NotImplementedError(
"log_survival_function is not implemented: {}".format(
type(self).__name__))
def _call_log_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_survival_function(value, **kwargs)
except NotImplementedError:
return math_ops.log1p(-self.cdf(value, **kwargs))
except NotImplementedError as original_exception:
try:
return math_ops.log1p(-self.cdf(value, **kwargs))
except NotImplementedError:
raise original_exception
def log_survival_function(self, value, name="log_survival_function"):
"""Log survival function.
@ -880,15 +904,19 @@ class Distribution(_BaseDistribution):
return self._call_log_survival_function(value, name)
def _survival_function(self, value):
raise NotImplementedError("survival_function is not implemented")
raise NotImplementedError("survival_function is not implemented: {}".format(
type(self).__name__))
def _call_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._survival_function(value, **kwargs)
except NotImplementedError:
return 1. - self.cdf(value, **kwargs)
except NotImplementedError as original_exception:
try:
return 1. - self.cdf(value, **kwargs)
except NotImplementedError:
raise original_exception
def survival_function(self, value, name="survival_function"):
"""Survival function.
@ -912,7 +940,8 @@ class Distribution(_BaseDistribution):
return self._call_survival_function(value, name)
def _entropy(self):
raise NotImplementedError("entropy is not implemented")
raise NotImplementedError("entropy is not implemented: {}".format(
type(self).__name__))
def entropy(self, name="entropy"):
"""Shannon entropy in nats."""
@ -920,7 +949,8 @@ class Distribution(_BaseDistribution):
return self._entropy()
def _mean(self):
raise NotImplementedError("mean is not implemented")
raise NotImplementedError("mean is not implemented: {}".format(
type(self).__name__))
def mean(self, name="mean"):
"""Mean."""
@ -928,7 +958,8 @@ class Distribution(_BaseDistribution):
return self._mean()
def _quantile(self, value):
raise NotImplementedError("quantile is not implemented")
raise NotImplementedError("quantile is not implemented: {}".format(
type(self).__name__))
def _call_quantile(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
@ -955,7 +986,8 @@ class Distribution(_BaseDistribution):
return self._call_quantile(value, name)
def _variance(self):
raise NotImplementedError("variance is not implemented")
raise NotImplementedError("variance is not implemented: {}".format(
type(self).__name__))
def variance(self, name="variance"):
"""Variance.
@ -979,11 +1011,15 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._variance()
except NotImplementedError:
return math_ops.square(self._stddev())
except NotImplementedError as original_exception:
try:
return math_ops.square(self._stddev())
except NotImplementedError:
raise original_exception
def _stddev(self):
raise NotImplementedError("stddev is not implemented")
raise NotImplementedError("stddev is not implemented: {}".format(
type(self).__name__))
def stddev(self, name="stddev"):
"""Standard deviation.
@ -1008,11 +1044,15 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._stddev()
except NotImplementedError:
return math_ops.sqrt(self._variance())
except NotImplementedError as original_exception:
try:
return math_ops.sqrt(self._variance())
except NotImplementedError:
raise original_exception
def _covariance(self):
raise NotImplementedError("covariance is not implemented")
raise NotImplementedError("covariance is not implemented: {}".format(
type(self).__name__))
def covariance(self, name="covariance"):
"""Covariance.
@ -1054,7 +1094,8 @@ class Distribution(_BaseDistribution):
return self._covariance()
def _mode(self):
raise NotImplementedError("mode is not implemented")
raise NotImplementedError("mode is not implemented: {}".format(
type(self).__name__))
def mode(self, name="mode"):
"""Mode."""