Distributions should raise the original exception (log_prob not implemented) instead of the fallback exception (prob not implemented).
Additionally, in a nested structure of transformed distributions, it can be useful to know which distribution is raising this error. PiperOrigin-RevId: 213618306
This commit is contained in:
parent
22ff5db5d4
commit
7bc9f39687
@ -601,7 +601,8 @@ class Distribution(_BaseDistribution):
|
||||
return type(self)(**parameters)
|
||||
|
||||
def _batch_shape_tensor(self):
|
||||
raise NotImplementedError("batch_shape_tensor is not implemented")
|
||||
raise NotImplementedError(
|
||||
"batch_shape_tensor is not implemented: {}".format(type(self).__name__))
|
||||
|
||||
def batch_shape_tensor(self, name="batch_shape_tensor"):
|
||||
"""Shape of a single sample from a single event index as a 1-D `Tensor`.
|
||||
@ -640,7 +641,8 @@ class Distribution(_BaseDistribution):
|
||||
return tensor_shape.as_shape(self._batch_shape())
|
||||
|
||||
def _event_shape_tensor(self):
|
||||
raise NotImplementedError("event_shape_tensor is not implemented")
|
||||
raise NotImplementedError(
|
||||
"event_shape_tensor is not implemented: {}".format(type(self).__name__))
|
||||
|
||||
def event_shape_tensor(self, name="event_shape_tensor"):
|
||||
"""Shape of a single sample from a single batch as a 1-D int32 `Tensor`.
|
||||
@ -701,7 +703,8 @@ class Distribution(_BaseDistribution):
|
||||
name="is_scalar_batch")
|
||||
|
||||
def _sample_n(self, n, seed=None):
|
||||
raise NotImplementedError("sample_n is not implemented")
|
||||
raise NotImplementedError("sample_n is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
|
||||
with self._name_scope(name, values=[sample_shape]):
|
||||
@ -733,15 +736,19 @@ class Distribution(_BaseDistribution):
|
||||
return self._call_sample_n(sample_shape, seed, name)
|
||||
|
||||
def _log_prob(self, value):
|
||||
raise NotImplementedError("log_prob is not implemented")
|
||||
raise NotImplementedError("log_prob is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _call_log_prob(self, value, name, **kwargs):
|
||||
with self._name_scope(name, values=[value]):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
try:
|
||||
return self._log_prob(value, **kwargs)
|
||||
except NotImplementedError:
|
||||
return math_ops.log(self._prob(value, **kwargs))
|
||||
except NotImplementedError as original_exception:
|
||||
try:
|
||||
return math_ops.log(self._prob(value, **kwargs))
|
||||
except NotImplementedError:
|
||||
raise original_exception
|
||||
|
||||
def log_prob(self, value, name="log_prob"):
|
||||
"""Log probability density/mass function.
|
||||
@ -757,15 +764,19 @@ class Distribution(_BaseDistribution):
|
||||
return self._call_log_prob(value, name)
|
||||
|
||||
def _prob(self, value):
|
||||
raise NotImplementedError("prob is not implemented")
|
||||
raise NotImplementedError("prob is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _call_prob(self, value, name, **kwargs):
|
||||
with self._name_scope(name, values=[value]):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
try:
|
||||
return self._prob(value, **kwargs)
|
||||
except NotImplementedError:
|
||||
return math_ops.exp(self._log_prob(value, **kwargs))
|
||||
except NotImplementedError as original_exception:
|
||||
try:
|
||||
return math_ops.exp(self._log_prob(value, **kwargs))
|
||||
except NotImplementedError:
|
||||
raise original_exception
|
||||
|
||||
def prob(self, value, name="prob"):
|
||||
"""Probability density/mass function.
|
||||
@ -781,15 +792,19 @@ class Distribution(_BaseDistribution):
|
||||
return self._call_prob(value, name)
|
||||
|
||||
def _log_cdf(self, value):
|
||||
raise NotImplementedError("log_cdf is not implemented")
|
||||
raise NotImplementedError("log_cdf is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _call_log_cdf(self, value, name, **kwargs):
|
||||
with self._name_scope(name, values=[value]):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
try:
|
||||
return self._log_cdf(value, **kwargs)
|
||||
except NotImplementedError:
|
||||
return math_ops.log(self._cdf(value, **kwargs))
|
||||
except NotImplementedError as original_exception:
|
||||
try:
|
||||
return math_ops.log(self._cdf(value, **kwargs))
|
||||
except NotImplementedError:
|
||||
raise original_exception
|
||||
|
||||
def log_cdf(self, value, name="log_cdf"):
|
||||
"""Log cumulative distribution function.
|
||||
@ -815,15 +830,19 @@ class Distribution(_BaseDistribution):
|
||||
return self._call_log_cdf(value, name)
|
||||
|
||||
def _cdf(self, value):
|
||||
raise NotImplementedError("cdf is not implemented")
|
||||
raise NotImplementedError("cdf is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _call_cdf(self, value, name, **kwargs):
|
||||
with self._name_scope(name, values=[value]):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
try:
|
||||
return self._cdf(value, **kwargs)
|
||||
except NotImplementedError:
|
||||
return math_ops.exp(self._log_cdf(value, **kwargs))
|
||||
except NotImplementedError as original_exception:
|
||||
try:
|
||||
return math_ops.exp(self._log_cdf(value, **kwargs))
|
||||
except NotImplementedError:
|
||||
raise original_exception
|
||||
|
||||
def cdf(self, value, name="cdf"):
|
||||
"""Cumulative distribution function.
|
||||
@ -845,15 +864,20 @@ class Distribution(_BaseDistribution):
|
||||
return self._call_cdf(value, name)
|
||||
|
||||
def _log_survival_function(self, value):
|
||||
raise NotImplementedError("log_survival_function is not implemented")
|
||||
raise NotImplementedError(
|
||||
"log_survival_function is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _call_log_survival_function(self, value, name, **kwargs):
|
||||
with self._name_scope(name, values=[value]):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
try:
|
||||
return self._log_survival_function(value, **kwargs)
|
||||
except NotImplementedError:
|
||||
return math_ops.log1p(-self.cdf(value, **kwargs))
|
||||
except NotImplementedError as original_exception:
|
||||
try:
|
||||
return math_ops.log1p(-self.cdf(value, **kwargs))
|
||||
except NotImplementedError:
|
||||
raise original_exception
|
||||
|
||||
def log_survival_function(self, value, name="log_survival_function"):
|
||||
"""Log survival function.
|
||||
@ -880,15 +904,19 @@ class Distribution(_BaseDistribution):
|
||||
return self._call_log_survival_function(value, name)
|
||||
|
||||
def _survival_function(self, value):
|
||||
raise NotImplementedError("survival_function is not implemented")
|
||||
raise NotImplementedError("survival_function is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _call_survival_function(self, value, name, **kwargs):
|
||||
with self._name_scope(name, values=[value]):
|
||||
value = ops.convert_to_tensor(value, name="value")
|
||||
try:
|
||||
return self._survival_function(value, **kwargs)
|
||||
except NotImplementedError:
|
||||
return 1. - self.cdf(value, **kwargs)
|
||||
except NotImplementedError as original_exception:
|
||||
try:
|
||||
return 1. - self.cdf(value, **kwargs)
|
||||
except NotImplementedError:
|
||||
raise original_exception
|
||||
|
||||
def survival_function(self, value, name="survival_function"):
|
||||
"""Survival function.
|
||||
@ -912,7 +940,8 @@ class Distribution(_BaseDistribution):
|
||||
return self._call_survival_function(value, name)
|
||||
|
||||
def _entropy(self):
|
||||
raise NotImplementedError("entropy is not implemented")
|
||||
raise NotImplementedError("entropy is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def entropy(self, name="entropy"):
|
||||
"""Shannon entropy in nats."""
|
||||
@ -920,7 +949,8 @@ class Distribution(_BaseDistribution):
|
||||
return self._entropy()
|
||||
|
||||
def _mean(self):
|
||||
raise NotImplementedError("mean is not implemented")
|
||||
raise NotImplementedError("mean is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def mean(self, name="mean"):
|
||||
"""Mean."""
|
||||
@ -928,7 +958,8 @@ class Distribution(_BaseDistribution):
|
||||
return self._mean()
|
||||
|
||||
def _quantile(self, value):
|
||||
raise NotImplementedError("quantile is not implemented")
|
||||
raise NotImplementedError("quantile is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def _call_quantile(self, value, name, **kwargs):
|
||||
with self._name_scope(name, values=[value]):
|
||||
@ -955,7 +986,8 @@ class Distribution(_BaseDistribution):
|
||||
return self._call_quantile(value, name)
|
||||
|
||||
def _variance(self):
|
||||
raise NotImplementedError("variance is not implemented")
|
||||
raise NotImplementedError("variance is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def variance(self, name="variance"):
|
||||
"""Variance.
|
||||
@ -979,11 +1011,15 @@ class Distribution(_BaseDistribution):
|
||||
with self._name_scope(name):
|
||||
try:
|
||||
return self._variance()
|
||||
except NotImplementedError:
|
||||
return math_ops.square(self._stddev())
|
||||
except NotImplementedError as original_exception:
|
||||
try:
|
||||
return math_ops.square(self._stddev())
|
||||
except NotImplementedError:
|
||||
raise original_exception
|
||||
|
||||
def _stddev(self):
|
||||
raise NotImplementedError("stddev is not implemented")
|
||||
raise NotImplementedError("stddev is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def stddev(self, name="stddev"):
|
||||
"""Standard deviation.
|
||||
@ -1008,11 +1044,15 @@ class Distribution(_BaseDistribution):
|
||||
with self._name_scope(name):
|
||||
try:
|
||||
return self._stddev()
|
||||
except NotImplementedError:
|
||||
return math_ops.sqrt(self._variance())
|
||||
except NotImplementedError as original_exception:
|
||||
try:
|
||||
return math_ops.sqrt(self._variance())
|
||||
except NotImplementedError:
|
||||
raise original_exception
|
||||
|
||||
def _covariance(self):
|
||||
raise NotImplementedError("covariance is not implemented")
|
||||
raise NotImplementedError("covariance is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def covariance(self, name="covariance"):
|
||||
"""Covariance.
|
||||
@ -1054,7 +1094,8 @@ class Distribution(_BaseDistribution):
|
||||
return self._covariance()
|
||||
|
||||
def _mode(self):
|
||||
raise NotImplementedError("mode is not implemented")
|
||||
raise NotImplementedError("mode is not implemented: {}".format(
|
||||
type(self).__name__))
|
||||
|
||||
def mode(self, name="mode"):
|
||||
"""Mode."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user