Support non-static shape in tf.distributions.Categorical
.
PiperOrigin-RevId: 200596358
This commit is contained in:
parent
7ccf1937b8
commit
d943de372a
@ -93,6 +93,7 @@ cuda_py_test(
|
||||
size = "small",
|
||||
srcs = ["categorical_test.py"],
|
||||
additional_deps = [
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python/ops/distributions",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -40,7 +41,7 @@ def make_categorical(batch_shape, num_classes, dtype=dtypes.int32):
|
||||
return categorical.Categorical(logits, dtype=dtype)
|
||||
|
||||
|
||||
class CategoricalTest(test.TestCase):
|
||||
class CategoricalTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testP(self):
|
||||
p = [0.2, 0.8]
|
||||
@ -131,7 +132,7 @@ class CategoricalTest(test.TestCase):
|
||||
with self.test_session():
|
||||
self.assertAllClose(dist.prob(0).eval(), 0.2)
|
||||
|
||||
def testCDFWithDynamicEventShape(self):
|
||||
def testCDFWithDynamicEventShapeKnownNdims(self):
|
||||
"""Test that dynamically-sized events with unknown shape work."""
|
||||
batch_size = 2
|
||||
histograms = array_ops.placeholder(dtype=dtypes.float32,
|
||||
@ -167,6 +168,21 @@ class CategoricalTest(test.TestCase):
|
||||
self.assertAllClose(actual_cdf_one, expected_cdf_one)
|
||||
self.assertAllClose(actual_cdf_two, expected_cdf_two)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]),
|
||||
("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
|
||||
[0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88]))
|
||||
def testCDFWithDynamicEventShapeUnknownNdims(
|
||||
self, events, histograms, expected_cdf):
|
||||
"""Test that dynamically-sized events with unknown shape work."""
|
||||
event_ph = array_ops.placeholder_with_default(events, shape=None)
|
||||
histograms_ph = array_ops.placeholder_with_default(histograms, shape=None)
|
||||
dist = categorical.Categorical(probs=histograms_ph)
|
||||
cdf_op = dist.cdf(event_ph)
|
||||
|
||||
actual_cdf = self.evaluate(cdf_op)
|
||||
self.assertAllClose(actual_cdf, expected_cdf)
|
||||
|
||||
def testCDFWithBatch(self):
|
||||
histograms = [[0.1, 0.2, 0.3, 0.25, 0.15],
|
||||
[0.0, 0.75, 0.2, 0.05, 0.0]]
|
||||
|
@ -32,12 +32,8 @@ from tensorflow.python.ops.distributions import util as distribution_util
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
|
||||
def _broadcast_cat_event_and_params(event, params, base_dtype):
|
||||
"""Broadcasts the event or distribution parameters."""
|
||||
if event.shape.ndims is None:
|
||||
raise NotImplementedError(
|
||||
"Cannot broadcast with an event tensor of unknown rank.")
|
||||
|
||||
if event.dtype.is_integer:
|
||||
pass
|
||||
elif event.dtype.is_floating:
|
||||
@ -47,15 +43,18 @@ def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
|
||||
else:
|
||||
raise TypeError("`value` should have integer `dtype` or "
|
||||
"`self.dtype` ({})".format(base_dtype))
|
||||
|
||||
if params.get_shape()[:-1] == event.get_shape():
|
||||
params = params
|
||||
else:
|
||||
params *= array_ops.ones_like(
|
||||
array_ops.expand_dims(event, -1), dtype=params.dtype)
|
||||
shape_known_statically = (
|
||||
params.shape.ndims is not None and
|
||||
params.shape[:-1].is_fully_defined() and
|
||||
event.shape.is_fully_defined())
|
||||
if not shape_known_statically or params.shape[:-1] != event.shape:
|
||||
params *= array_ops.ones_like(event[..., array_ops.newaxis],
|
||||
dtype=params.dtype)
|
||||
params_shape = array_ops.shape(params)[:-1]
|
||||
event *= array_ops.ones(params_shape, dtype=event.dtype)
|
||||
event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1]))
|
||||
if params.shape.ndims is not None:
|
||||
event.set_shape(tensor_shape.TensorShape(params.shape[:-1]))
|
||||
|
||||
return event, params
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user