Have TensorFlow Distributions share name scopes across method calls.

END_PUBLIC

*** Reason for rollback ***

Roll forward, allowing distributions to have the same names across objects.

*** Original change description ***

BEGIN_PUBLIC
Automated g4 rollback of changelist 190728742

PiperOrigin-RevId: 193428925
parent 8c66f22230
commit 427a458ae6
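This commit applies one pattern across the distributions codebase: every constructor that opened `with ops.name_scope(name):` now opens `with ops.name_scope(name) as name:`, so the uniquified scope string that `name_scope` yields (for example `"x_1/"` when `"x"` is already taken) is captured and stored. Because a scope string ending in `"/"` is re-entered verbatim rather than uniquified again, later method calls can land their ops under the object's original scope. A minimal sketch of the two behaviors, using TF 1.x graph-mode semantics (variable names here are illustrative, not from the diff):

```python
import tensorflow as tf  # TF 1.x graph-mode semantics assumed

with tf.Graph().as_default():
  # Requesting the scope "x" twice yields two distinct scopes.
  with tf.name_scope("x") as scope_a:  # scope_a == "x/"
    pass
  with tf.name_scope("x") as scope_b:  # scope_b == "x_1/" (uniquified)
    pass
  # Re-entering a captured scope string (it ends in "/") reuses it verbatim,
  # so new ops land under the original prefix instead of "x_2/".
  with tf.name_scope(scope_a):
    c = tf.constant(0., name="c")  # c.name == "x/c:0"
```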
@@ -190,11 +190,30 @@ class DistributionTest(test.TestCase):
     y = dist._set_sample_static_shape(x, sample_shape)
     self.assertTrue(y.get_shape().ndims is None)
 
+  def testNameScopeWorksCorrectly(self):
+    x = tfd.Normal(loc=0., scale=1., name="x")
+    x_duplicate = tfd.Normal(loc=0., scale=1., name="x")
+    with ops.name_scope("y") as name:
+      y = tfd.Bernoulli(logits=0., name=name)
+    x_sample = x.sample(name="custom_sample")
+    x_sample_duplicate = x.sample(name="custom_sample")
+    x_log_prob = x.log_prob(0., name="custom_log_prob")
+    x_duplicate_sample = x_duplicate.sample(name="custom_sample")
+
+    self.assertEqual(x.name, "x/")
+    self.assertEqual(x_duplicate.name, "x_1/")
+    self.assertEqual(y.name, "y/")
+    self.assertTrue(x_sample.name.startswith("x/custom_sample"))
+    self.assertTrue(x_sample_duplicate.name.startswith("x/custom_sample_1"))
+    self.assertTrue(x_log_prob.name.startswith("x/custom_log_prob"))
+    self.assertTrue(x_duplicate_sample.name.startswith(
+        "x_1/custom_sample"))
+
   def testStrWorksCorrectlyScalar(self):
     normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
     self.assertEqual(
         ("tf.distributions.Normal("
-         "\"Normal\", "
+         "\"Normal/\", "
          "batch_shape=(), "
          "event_shape=(), "
          "dtype=float16)"),  # Got the dtype right.
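The test above leans on two conventions introduced by this commit: a distribution's `name` is now the captured scope string, trailing `"/"` included, and each method call re-enters that scope before opening its own sub-scope (hence `"x/custom_sample"`, then `"x/custom_sample_1"` for the second sample). A simplified sketch of the scoping helper this implies; the real `Distribution._name_scope` differs in details such as also threading `graph_parents` into `values`:

```python
import contextlib

from tensorflow.python.framework import ops


@contextlib.contextmanager
def _name_scope(dist, name, values=None):
  """Re-enter the distribution's captured scope, then open a sub-scope."""
  with ops.name_scope(dist.name):  # dist.name ends in "/": reused verbatim.
    with ops.name_scope(name, values=values or []) as scope:
      yield scope  # e.g. "x/custom_sample/" or "x/custom_sample_1/"
```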
@@ -203,7 +222,7 @@ class DistributionTest(test.TestCase):
     chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
     self.assertEqual(
         ("tf.distributions.Chi2("
-         "\"silly\", "  # What a silly name that is!
+         "\"silly/\", "  # What a silly name that is!
          "batch_shape=(2,), "
          "event_shape=(), "
          "dtype=float32)"),
@@ -211,7 +230,7 @@ class DistributionTest(test.TestCase):
 
     exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
     self.assertEqual(
-        ("tf.distributions.Exponential(\"Exponential\", "
+        ("tf.distributions.Exponential(\"Exponential/\", "
          # No batch shape.
          "event_shape=(), "
          "dtype=float32)"),
@@ -222,7 +241,7 @@ class DistributionTest(test.TestCase):
         loc=np.zeros([2, 2]), name="MVN")
     self.assertEqual(
         ("tf.distributions.MultivariateNormalDiag("
-         "\"MVN\", "
+         "\"MVN/\", "
          "batch_shape=(2,), "
          "event_shape=(2,), "
          "dtype=float64)"),
@@ -233,7 +252,7 @@ class DistributionTest(test.TestCase):
         name="MVN2")
     self.assertEqual(
         ("tf.distributions.MultivariateNormalDiag("
-         "\"MVN2\", "
+         "\"MVN2/\", "
          "batch_shape=(?,), "  # Partially known.
          "event_shape=(3,), "
          "dtype=float32)"),
@@ -243,7 +262,7 @@ class DistributionTest(test.TestCase):
     normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
     self.assertEqual(
         ("<tf.distributions.Normal"
-         " 'Normal'"
+         " 'Normal/'"
          " batch_shape=()"
          " event_shape=()"
          " dtype=float16>"),  # Got the dtype right.
@@ -252,7 +271,7 @@ class DistributionTest(test.TestCase):
     chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
     self.assertEqual(
         ("<tf.distributions.Chi2"
-         " 'silly'"  # What a silly name that is!
+         " 'silly/'"  # What a silly name that is!
          " batch_shape=(2,)"
          " event_shape=()"
          " dtype=float32>"),
@@ -261,7 +280,7 @@ class DistributionTest(test.TestCase):
     exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
     self.assertEqual(
         ("<tf.distributions.Exponential"
-         " 'Exponential'"
+         " 'Exponential/'"
          " batch_shape=<unknown>"
          " event_shape=()"
          " dtype=float32>"),
@@ -272,7 +291,7 @@ class DistributionTest(test.TestCase):
         loc=np.zeros([2, 2]), name="MVN")
     self.assertEqual(
         ("<tf.distributions.MultivariateNormalDiag"
-         " 'MVN'"
+         " 'MVN/'"
          " batch_shape=(2,)"
          " event_shape=(2,)"
          " dtype=float64>"),
@@ -283,7 +302,7 @@ class DistributionTest(test.TestCase):
         name="MVN2")
     self.assertEqual(
         ("<tf.distributions.MultivariateNormalDiag"
-         " 'MVN2'"
+         " 'MVN2/'"
          " batch_shape=(?,)"  # Partially known.
          " event_shape=(3,)"
          " dtype=float32>"),
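Taken together, the revised expectations mean both `str` and `repr` now report the captured scope name with its trailing slash. For example, on a TF 1.x build that includes this change, the first test case's distribution prints as follows (output copied from the assertions above):

```python
import numpy as np
import tensorflow as tf

normal = tf.distributions.Normal(loc=np.float16(0), scale=np.float16(1))
print(str(normal))
# tf.distributions.Normal("Normal/", batch_shape=(), event_shape=(), dtype=float16)
print(repr(normal))
# <tf.distributions.Normal 'Normal/' batch_shape=() event_shape=() dtype=float16>
```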
@@ -52,7 +52,7 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
     mu = [1., 2.]
     sigma = [[1., 0.], [0., 1.]]
     mvn = ds.MultivariateNormalFullCovariance(mu, sigma, name="Billy")
-    self.assertEqual(mvn.name, "Billy")
+    self.assertEqual(mvn.name, "Billy/")
 
   def testDoesNotRaiseIfInitializedWithSymmetricMatrix(self):
     with self.test_session():
@@ -145,7 +145,7 @@ class Autoregressive(distribution_lib.Distribution):
       ValueError: if `num_steps < 1`.
     """
     parameters = locals()
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       self._distribution_fn = distribution_fn
       self._sample0 = sample0
       self._distribution0 = (distribution_fn() if sample0 is None

@@ -164,7 +164,7 @@ class Binomial(distribution.Distribution):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[total_count, logits, probs]):
+    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
       self._total_count = self._maybe_assert_valid_total_count(
           ops.convert_to_tensor(total_count, name="total_count"),
           validate_args)

@@ -121,7 +121,7 @@ class Cauchy(distribution.Distribution):
       TypeError: if `loc` and `scale` have different `dtype`.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[loc, scale]):
+    with ops.name_scope(name, values=[loc, scale]) as name:
      with ops.control_dependencies([check_ops.assert_positive(scale)]
                                    if validate_args else []):
        self._loc = array_ops.identity(loc, name="loc")

@@ -88,7 +88,7 @@ class Chi2(gamma.Gamma):
     # not true in the parent class "gamma."  therefore, passing
     # allow_nan_stats=True
     # through to the parent class results in unnecessary asserts.
-    with ops.name_scope(name, values=[df]):
+    with ops.name_scope(name, values=[df]) as name:
       with ops.control_dependencies([
           check_ops.assert_positive(df),
       ] if validate_args else []):
@@ -120,7 +120,7 @@ class Chi2WithAbsDf(Chi2):
                allow_nan_stats=True,
                name="Chi2WithAbsDf"):
     parameters = locals()
-    with ops.name_scope(name, values=[df]):
+    with ops.name_scope(name, values=[df]) as name:
       super(Chi2WithAbsDf, self).__init__(
           df=math_ops.floor(
               math_ops.abs(df, name="abs_df"),

@@ -87,7 +87,7 @@ class _BaseDeterministic(distribution.Distribution):
       ValueError: If `loc` is a scalar.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[loc, atol, rtol]):
+    with ops.name_scope(name, values=[loc, atol, rtol]) as name:
       loc = ops.convert_to_tensor(loc, name="loc")
       if is_vector and validate_args:
         msg = "Argument loc must be at least rank 1."

@@ -86,7 +86,7 @@ class Geometric(distribution.Distribution):
     """
 
     parameters = locals()
-    with ops.name_scope(name, values=[logits, probs]):
+    with ops.name_scope(name, values=[logits, probs]) as name:
       self._logits, self._probs = distribution_util.get_logits_and_probs(
           logits, probs, validate_args=validate_args, name=name)
 

@@ -125,7 +125,7 @@ class _Gumbel(distribution.Distribution):
       TypeError: if loc and scale are different dtypes.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[loc, scale]):
+    with ops.name_scope(name, values=[loc, scale]) as name:
       with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                     validate_args else []):
         self._loc = array_ops.identity(loc, name="loc")

@@ -106,7 +106,7 @@ class HalfNormal(distribution.Distribution):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[scale]):
+    with ops.name_scope(name, values=[scale]) as name:
       with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                     validate_args else []):
         self._scale = array_ops.identity(scale, name="scale")

@@ -119,7 +119,7 @@ class Independent(distribution_lib.Distribution):
     parameters = locals()
     name = name or "Independent" + distribution.name
     self._distribution = distribution
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       if reinterpreted_batch_ndims is None:
         reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims(
             distribution)

@@ -126,7 +126,7 @@ class InverseGamma(distribution.Distribution):
       TypeError: if `concentration` and `rate` are different dtypes.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[concentration, rate]):
+    with ops.name_scope(name, values=[concentration, rate]) as name:
       with ops.control_dependencies([
           check_ops.assert_positive(concentration),
           check_ops.assert_positive(rate),
@@ -281,7 +281,7 @@ class InverseGammaWithSoftplusConcentrationRate(InverseGamma):
                allow_nan_stats=True,
                name="InverseGammaWithSoftplusConcentrationRate"):
     parameters = locals()
-    with ops.name_scope(name, values=[concentration, rate]):
+    with ops.name_scope(name, values=[concentration, rate]) as name:
       super(InverseGammaWithSoftplusConcentrationRate, self).__init__(
           concentration=nn.softplus(concentration,
                                     name="softplus_concentration"),

@@ -151,6 +151,7 @@ class Kumaraswamy(transformed_distribution.TransformedDistribution):
         more of the statistic's batch members are undefined.
       name: Python `str` name prefixed to Ops created by this class.
     """
+    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
       concentration1 = ops.convert_to_tensor(
           concentration1, name="concentration1")
       concentration0 = ops.convert_to_tensor(

@@ -120,7 +120,7 @@ class Logistic(distribution.Distribution):
       TypeError: if loc and scale are different dtypes.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[loc, scale]):
+    with ops.name_scope(name, values=[loc, scale]) as name:
       with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                     validate_args else []):
         self._loc = array_ops.identity(loc, name="loc")

@@ -145,7 +145,7 @@ class Mixture(distribution.Distribution):
           "none of the components provide a static number of ndims")
 
     # Ensure that all batch and event ndims are consistent.
-    with ops.name_scope(name, values=[cat.logits]):
+    with ops.name_scope(name, values=[cat.logits]) as name:
       num_components = cat.event_size
       static_num_components = tensor_util.constant_value(num_components)
       if static_num_components is None:

@@ -131,7 +131,7 @@ class MixtureSameFamily(distribution.Distribution):
         `components_distribution` rightmost batch shape.
     """
     parameters = locals()
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       self._mixture_distribution = mixture_distribution
       self._components_distribution = components_distribution
       self._runtime_assertions = []

@@ -194,7 +194,7 @@ class MultivariateNormalDiag(
       ValueError: if at most `scale_identity_multiplier` is specified.
     """
     parameters = locals()
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       with ops.name_scope("init", values=[
           loc, scale_diag, scale_identity_multiplier]):
         # No need to validate_args while making diag_scale.  The returned
@@ -225,7 +225,7 @@ class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
                allow_nan_stats=True,
                name="MultivariateNormalDiagWithSoftplusScale"):
     parameters = locals()
-    with ops.name_scope(name, values=[scale_diag]):
+    with ops.name_scope(name, values=[scale_diag]) as name:
       super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
           loc=loc,
           scale_diag=nn.softplus(scale_diag),

@@ -218,7 +218,7 @@ class MultivariateNormalDiagPlusLowRank(
     parameters = locals()
     def _convert_to_tensor(x, name):
       return None if x is None else ops.convert_to_tensor(x, name=name)
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       with ops.name_scope("init", values=[
           loc, scale_diag, scale_identity_multiplier, scale_perturb_factor,
           scale_perturb_diag]):

@@ -159,7 +159,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
     parameters = locals()
 
     # Convert the covariance_matrix up to a scale_tril and call MVNTriL.
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       with ops.name_scope("init", values=[loc, covariance_matrix]):
         if covariance_matrix is None:
           scale_tril = None

@@ -176,7 +176,7 @@ class MultivariateNormalLinearOperator(
     if not scale.dtype.is_floating:
       raise TypeError("`scale` parameter must have floating-point dtype.")
 
-    with ops.name_scope(name, values=[loc] + scale.graph_parents):
+    with ops.name_scope(name, values=[loc] + scale.graph_parents) as name:
       # Since expand_dims doesn't preserve constant-ness, we obtain the
       # non-dynamic value if possible.
       loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc

@@ -184,7 +184,7 @@ class MultivariateNormalTriL(
       return None if x is None else ops.convert_to_tensor(x, name=name)
     if loc is None and scale_tril is None:
       raise ValueError("Must specify one or both of `loc`, `scale_tril`.")
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       with ops.name_scope("init", values=[loc, scale_tril]):
         loc = _convert_to_tensor(loc, name="loc")
         scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")

@@ -91,7 +91,7 @@ class NegativeBinomial(distribution.Distribution):
     """
 
     parameters = locals()
-    with ops.name_scope(name, values=[total_count, logits, probs]):
+    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
       self._logits, self._probs = distribution_util.get_logits_and_probs(
           logits, probs, validate_args=validate_args, name=name)
       with ops.control_dependencies(

@@ -116,7 +116,7 @@ class OneHotCategorical(distribution.Distribution):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[logits, probs]):
+    with ops.name_scope(name, values=[logits, probs]) as name:
       self._logits, self._probs = distribution_util.get_logits_and_probs(
           name=name, logits=logits, probs=probs, validate_args=validate_args,
           multidimensional=True)

@@ -94,7 +94,7 @@ class Poisson(distribution.Distribution):
       TypeError: if `log_rate` is not a float-type.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[rate]):
+    with ops.name_scope(name, values=[rate]) as name:
       if (rate is None) == (log_rate is None):
         raise ValueError("Must specify exactly one of `rate` and `log_rate`.")
       elif log_rate is None:

@@ -256,7 +256,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
         `dtype`.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[loc, scale]):
+    with ops.name_scope(name, values=[loc, scale]) as name:
       if loc is not None:
         loc = ops.convert_to_tensor(loc, name="loc")
       if scale is not None:

@@ -217,7 +217,7 @@ class QuantizedDistribution(distributions.Distribution):
     values = (
         list(distribution.parameters.values()) +
         [low, high])
-    with ops.name_scope(name, values=values):
+    with ops.name_scope(name, values=values) as name:
       self._dist = distribution
 
       if low is not None:

@@ -166,7 +166,7 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution):
       ValueError: If both `probs` and `logits` are passed, or if neither.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[logits, probs, temperature]):
+    with ops.name_scope(name, values=[logits, probs, temperature]) as name:
       with ops.control_dependencies([check_ops.assert_positive(temperature)]
                                     if validate_args else []):
         self._temperature = array_ops.identity(temperature, name="temperature")

@@ -163,7 +163,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[logits, probs, temperature]):
+    with ops.name_scope(name, values=[logits, probs, temperature]) as name:
 
       self._logits, self._probs = distribution_util.get_logits_and_probs(
           name=name, logits=logits, probs=probs, validate_args=validate_args,

@@ -134,7 +134,8 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
     """
     parameters = locals()
 
-    with ops.name_scope(name, values=[loc, scale, skewness, tailweight]):
+    with ops.name_scope(name,
+                        values=[loc, scale, skewness, tailweight]) as name:
       loc = ops.convert_to_tensor(loc, name="loc")
       dtype = loc.dtype
       scale = ops.convert_to_tensor(scale, name="scale", dtype=dtype)

@@ -396,7 +396,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
       ValueError: if `not distribution.is_scalar_event`.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[mix_loc, temperature]):
+    with ops.name_scope(name, values=[mix_loc, temperature]) as name:
       if not scale or len(scale) < 2:
         raise ValueError("Must specify list (or list-like object) of scale "
                          "LinearOperators, one for each component with "

@@ -176,7 +176,7 @@ class VectorExponentialDiag(
       ValueError: if at most `scale_identity_multiplier` is specified.
     """
     parameters = locals()
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       with ops.name_scope("init", values=[
           loc, scale_diag, scale_identity_multiplier]):
         # No need to validate_args while making diag_scale.  The returned

@@ -181,7 +181,7 @@ class VectorExponentialLinearOperator(
     if not scale.dtype.is_floating:
       raise TypeError("`scale` parameter must have floating-point dtype.")
 
-    with ops.name_scope(name, values=[loc] + scale.graph_parents):
+    with ops.name_scope(name, values=[loc] + scale.graph_parents) as name:
       # Since expand_dims doesn't preserve constant-ness, we obtain the
       # non-dynamic value if possible.
       loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc

@@ -169,7 +169,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
         name,
         values=[
             loc, scale_diag, scale_identity_multiplier, skewness, tailweight
-        ]):
+        ]) as name:
       loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc
       tailweight = 1. if tailweight is None else tailweight
       has_default_skewness = skewness is None

@@ -178,7 +178,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
     parameters = locals()
     graph_parents = [df, loc, scale_identity_multiplier, scale_diag,
                      scale_tril, scale_perturb_factor, scale_perturb_diag]
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       with ops.name_scope("init", values=graph_parents):
         # The shape of the _VectorStudentT distribution is governed by the
         # relationship between df.batch_shape and affine.batch_shape. In

@@ -109,7 +109,7 @@ class _WishartLinearOperator(distribution.Distribution):
     """
     parameters = locals()
     self._cholesky_input_output_matrices = cholesky_input_output_matrices
-    with ops.name_scope(name) as ns:
+    with ops.name_scope(name) as name:
       with ops.name_scope("init", values=[df, scale_operator]):
         if not scale_operator.dtype.is_floating:
           raise TypeError(
@@ -163,7 +163,7 @@ class _WishartLinearOperator(distribution.Distribution):
         parameters=parameters,
         graph_parents=([self._df, self._dimension] +
                        self._scale_operator.graph_parents),
-        name=ns)
+        name=name)
 
   @property
   def df(self):
@@ -531,7 +531,7 @@ class WishartCholesky(_WishartLinearOperator):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[scale]):
+    with ops.name_scope(name, values=[scale]) as name:
       with ops.name_scope("init", values=[scale]):
         scale = ops.convert_to_tensor(scale)
         if validate_args:
@@ -647,7 +647,7 @@ class WishartFull(_WishartLinearOperator):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name) as ns:
+    with ops.name_scope(name) as name:
       with ops.name_scope("init", values=[scale]):
         scale = ops.convert_to_tensor(scale)
         if validate_args:
@@ -666,5 +666,5 @@ class WishartFull(_WishartLinearOperator):
         cholesky_input_output_matrices=cholesky_input_output_matrices,
         validate_args=validate_args,
         allow_nan_stats=allow_nan_stats,
-        name=ns)
+        name=name)
     self._parameters = parameters

@@ -72,7 +72,7 @@ class Bernoulli(distribution.Distribution):
       ValueError: If p and logits are passed, or if neither are passed.
     """
     parameters = locals()
-    with ops.name_scope(name):
+    with ops.name_scope(name) as name:
       self._logits, self._probs = distribution_util.get_logits_and_probs(
           logits=logits,
           probs=probs,

@@ -151,7 +151,7 @@ class Beta(distribution.Distribution):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[concentration1, concentration0]):
+    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
       self._concentration1 = self._maybe_assert_valid_concentration(
           ops.convert_to_tensor(concentration1, name="concentration1"),
           validate_args)
@@ -323,7 +323,7 @@ class BetaWithSoftplusConcentration(Beta):
                name="BetaWithSoftplusConcentration"):
     parameters = locals()
     with ops.name_scope(name, values=[concentration1,
-                                      concentration0]) as ns:
+                                      concentration0]) as name:
       super(BetaWithSoftplusConcentration, self).__init__(
           concentration1=nn.softplus(concentration1,
                                      name="softplus_concentration1"),
@@ -331,7 +331,7 @@ class BetaWithSoftplusConcentration(Beta):
                                      name="softplus_concentration0"),
           validate_args=validate_args,
           allow_nan_stats=allow_nan_stats,
-          name=ns)
+          name=name)
     self._parameters = parameters
 
 

@@ -183,7 +183,7 @@ class Categorical(distribution.Distribution):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[logits, probs]):
+    with ops.name_scope(name, values=[logits, probs]) as name:
       self._logits, self._probs = distribution_util.get_logits_and_probs(
           logits=logits,
           probs=probs,

@@ -155,7 +155,7 @@ class Dirichlet(distribution.Distribution):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[concentration]):
+    with ops.name_scope(name, values=[concentration]) as name:
       self._concentration = self._maybe_assert_valid_concentration(
           ops.convert_to_tensor(concentration, name="concentration"),
           validate_args)

@@ -192,7 +192,7 @@ class DirichletMultinomial(distribution.Distribution):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[total_count, concentration]):
+    with ops.name_scope(name, values=[total_count, concentration]) as name:
       # Broadcasting works because:
       # * The broadcasting convention is to prepend dimensions of size [1], and
       #   we use the last dimension for the distribution, whereas
@@ -434,13 +434,17 @@ class Distribution(_BaseDistribution):
     for i, t in enumerate(graph_parents):
       if t is None or not tensor_util.is_tensor(t):
         raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
+    if not name or name[-1] != "/":  # `name` is not a name scope.
+      non_unique_name = name or type(self).__name__
+      with ops.name_scope(non_unique_name) as name:
+        pass
     self._dtype = dtype
     self._reparameterization_type = reparameterization_type
     self._allow_nan_stats = allow_nan_stats
     self._validate_args = validate_args
     self._parameters = parameters or {}
     self._graph_parents = graph_parents
-    self._name = name or type(self).__name__
+    self._name = name
 
   @classmethod
   def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
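This base-class hunk is the linchpin: constructors now hand in a captured scope string whenever they can, and `Distribution.__init__` normalizes whatever it receives, entering a fresh scope only when the name is not already one. The four added lines, extracted here as a standalone helper for clarity (a paraphrase for illustration, not a public API):

```python
from tensorflow.python.framework import ops


def _normalize_name(raw_name, cls_name):
  """Return a graph-unique name-scope string ending in '/'."""
  if not raw_name or raw_name[-1] != "/":  # Not already a captured scope.
    non_unique_name = raw_name or cls_name
    # Entering (and immediately exiting) the scope reserves a unique
    # version of it: "x" becomes "x/", or "x_1/" on a collision.
    with ops.name_scope(non_unique_name) as raw_name:
      pass
  return raw_name
```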
@@ -95,7 +95,7 @@ class Exponential(gamma.Gamma):
     # true in the parent class "Gamma."  Therefore, passing
     # allow_nan_stats=True
     # through to the parent class results in unnecessary asserts.
-    with ops.name_scope(name, values=[rate]):
+    with ops.name_scope(name, values=[rate]) as name:
       self._rate = ops.convert_to_tensor(rate, name="rate")
       super(Exponential, self).__init__(
           concentration=array_ops.ones([], dtype=self._rate.dtype),
@@ -144,7 +144,7 @@ class ExponentialWithSoftplusRate(Exponential):
                allow_nan_stats=True,
                name="ExponentialWithSoftplusRate"):
     parameters = locals()
-    with ops.name_scope(name, values=[rate]):
+    with ops.name_scope(name, values=[rate]) as name:
       super(ExponentialWithSoftplusRate, self).__init__(
           rate=nn.softplus(rate, name="softplus_rate"),
           validate_args=validate_args,

@@ -127,7 +127,7 @@ class Gamma(distribution.Distribution):
       TypeError: if `concentration` and `rate` are different dtypes.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[concentration, rate]):
+    with ops.name_scope(name, values=[concentration, rate]) as name:
       with ops.control_dependencies([
           check_ops.assert_positive(concentration),
           check_ops.assert_positive(rate),
@@ -262,7 +262,7 @@ class GammaWithSoftplusConcentrationRate(Gamma):
                allow_nan_stats=True,
                name="GammaWithSoftplusConcentrationRate"):
     parameters = locals()
-    with ops.name_scope(name, values=[concentration, rate]):
+    with ops.name_scope(name, values=[concentration, rate]) as name:
       super(GammaWithSoftplusConcentrationRate, self).__init__(
           concentration=nn.softplus(concentration,
                                     name="softplus_concentration"),

@@ -101,7 +101,7 @@ class Laplace(distribution.Distribution):
       TypeError: if `loc` and `scale` are of different dtype.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[loc, scale]):
+    with ops.name_scope(name, values=[loc, scale]) as name:
       with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                     validate_args else []):
         self._loc = array_ops.identity(loc, name="loc")
@@ -218,7 +218,7 @@ class LaplaceWithSoftplusScale(Laplace):
                allow_nan_stats=True,
                name="LaplaceWithSoftplusScale"):
     parameters = locals()
-    with ops.name_scope(name, values=[loc, scale]):
+    with ops.name_scope(name, values=[loc, scale]) as name:
       super(LaplaceWithSoftplusScale, self).__init__(
           loc=loc,
           scale=nn.softplus(scale, name="softplus_scale"),

@@ -183,7 +183,7 @@ class Multinomial(distribution.Distribution):
       name: Python `str` name prefixed to Ops created by this class.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[total_count, logits, probs]):
+    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
       self._total_count = ops.convert_to_tensor(total_count, name="total_count")
       if validate_args:
         self._total_count = (

@@ -132,7 +132,7 @@ class Normal(distribution.Distribution):
       TypeError: if `loc` and `scale` have different `dtype`.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[loc, scale]):
+    with ops.name_scope(name, values=[loc, scale]) as name:
       with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                     validate_args else []):
         self._loc = array_ops.identity(loc, name="loc")
@@ -244,7 +244,7 @@ class NormalWithSoftplusScale(Normal):
                allow_nan_stats=True,
                name="NormalWithSoftplusScale"):
     parameters = locals()
-    with ops.name_scope(name, values=[scale]):
+    with ops.name_scope(name, values=[scale]) as name:
       super(NormalWithSoftplusScale, self).__init__(
           loc=loc,
           scale=nn.softplus(scale, name="softplus_scale"),

@@ -158,7 +158,7 @@ class StudentT(distribution.Distribution):
       TypeError: if loc and scale are different dtypes.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[df, loc, scale]):
+    with ops.name_scope(name, values=[df, loc, scale]) as name:
       with ops.control_dependencies([check_ops.assert_positive(df)]
                                     if validate_args else []):
         self._df = array_ops.identity(df, name="df")
@@ -350,7 +350,7 @@ class StudentTWithAbsDfSoftplusScale(StudentT):
                allow_nan_stats=True,
                name="StudentTWithAbsDfSoftplusScale"):
     parameters = locals()
-    with ops.name_scope(name, values=[df, scale]):
+    with ops.name_scope(name, values=[df, scale]) as name:
       super(StudentTWithAbsDfSoftplusScale, self).__init__(
           df=math_ops.floor(math_ops.abs(df)),
          loc=loc,

@@ -257,7 +257,7 @@ class TransformedDistribution(distribution_lib.Distribution):
     parameters = locals()
     name = name or (("" if bijector is None else bijector.name) +
                     distribution.name)
-    with ops.name_scope(name, values=[event_shape, batch_shape]):
+    with ops.name_scope(name, values=[event_shape, batch_shape]) as name:
       # For convenience we define some handy constants.
       self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero")
       self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")

@@ -103,7 +103,7 @@ class Uniform(distribution.Distribution):
       InvalidArgumentError: if `low >= high` and `validate_args=False`.
     """
     parameters = locals()
-    with ops.name_scope(name, values=[low, high]):
+    with ops.name_scope(name, values=[low, high]) as name:
       with ops.control_dependencies([
           check_ops.assert_less(
               low, high, message="uniform not defined when low >= high.")