Support non-static shape in tf.distributions.Categorical.

PiperOrigin-RevId: 200596358
This commit is contained in:
Christopher Suter 2018-06-14 12:05:53 -07:00 committed by TensorFlower Gardener
parent 7ccf1937b8
commit d943de372a
3 changed files with 30 additions and 14 deletions

View File

@ -93,6 +93,7 @@ cuda_py_test(
size = "small",
srcs = ["categorical_test.py"],
additional_deps = [
"@absl_py//absl/testing:parameterized",
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import constant_op
@ -40,7 +41,7 @@ def make_categorical(batch_shape, num_classes, dtype=dtypes.int32):
return categorical.Categorical(logits, dtype=dtype)
class CategoricalTest(test.TestCase):
class CategoricalTest(test.TestCase, parameterized.TestCase):
def testP(self):
p = [0.2, 0.8]
@ -131,7 +132,7 @@ class CategoricalTest(test.TestCase):
with self.test_session():
self.assertAllClose(dist.prob(0).eval(), 0.2)
def testCDFWithDynamicEventShape(self):
def testCDFWithDynamicEventShapeKnownNdims(self):
"""Test that dynamically-sized events with unknown shape work."""
batch_size = 2
histograms = array_ops.placeholder(dtype=dtypes.float32,
@ -167,6 +168,21 @@ class CategoricalTest(test.TestCase):
self.assertAllClose(actual_cdf_one, expected_cdf_one)
self.assertAllClose(actual_cdf_two, expected_cdf_two)
@parameterized.named_parameters(
("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]),
("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
[0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88]))
def testCDFWithDynamicEventShapeUnknownNdims(
self, events, histograms, expected_cdf):
"""Test that dynamically-sized events with unknown shape work."""
event_ph = array_ops.placeholder_with_default(events, shape=None)
histograms_ph = array_ops.placeholder_with_default(histograms, shape=None)
dist = categorical.Categorical(probs=histograms_ph)
cdf_op = dist.cdf(event_ph)
actual_cdf = self.evaluate(cdf_op)
self.assertAllClose(actual_cdf, expected_cdf)
def testCDFWithBatch(self):
histograms = [[0.1, 0.2, 0.3, 0.25, 0.15],
[0.0, 0.75, 0.2, 0.05, 0.0]]

View File

@ -32,12 +32,8 @@ from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
def _broadcast_cat_event_and_params(event, params, base_dtype):
"""Broadcasts the event or distribution parameters."""
if event.shape.ndims is None:
raise NotImplementedError(
"Cannot broadcast with an event tensor of unknown rank.")
if event.dtype.is_integer:
pass
elif event.dtype.is_floating:
@ -47,15 +43,18 @@ def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
else:
raise TypeError("`value` should have integer `dtype` or "
"`self.dtype` ({})".format(base_dtype))
if params.get_shape()[:-1] == event.get_shape():
params = params
else:
params *= array_ops.ones_like(
array_ops.expand_dims(event, -1), dtype=params.dtype)
shape_known_statically = (
params.shape.ndims is not None and
params.shape[:-1].is_fully_defined() and
event.shape.is_fully_defined())
if not shape_known_statically or params.shape[:-1] != event.shape:
params *= array_ops.ones_like(event[..., array_ops.newaxis],
dtype=params.dtype)
params_shape = array_ops.shape(params)[:-1]
event *= array_ops.ones(params_shape, dtype=event.dtype)
event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1]))
if params.shape.ndims is not None:
event.set_shape(tensor_shape.TensorShape(params.shape[:-1]))
return event, params