From d943de372a989ca6bc44058e35ba9f26591b42b4 Mon Sep 17 00:00:00 2001 From: Christopher Suter Date: Thu, 14 Jun 2018 12:05:53 -0700 Subject: [PATCH] Support non-static shape in `tf.distributions.Categorical`. PiperOrigin-RevId: 200596358 --- .../python/kernel_tests/distributions/BUILD | 1 + .../distributions/categorical_test.py | 20 ++++++++++++++-- .../python/ops/distributions/categorical.py | 23 +++++++++---------- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD index cf2e8832fd5..985922245e9 100644 --- a/tensorflow/python/kernel_tests/distributions/BUILD +++ b/tensorflow/python/kernel_tests/distributions/BUILD @@ -93,6 +93,7 @@ cuda_py_test( size = "small", srcs = ["categorical_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python/ops/distributions", "//third_party/py/numpy", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py index ca2358fe999..68b4ffdb58c 100644 --- a/tensorflow/python/kernel_tests/distributions/categorical_test.py +++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.framework import constant_op @@ -40,7 +41,7 @@ def make_categorical(batch_shape, num_classes, dtype=dtypes.int32): return categorical.Categorical(logits, dtype=dtype) -class CategoricalTest(test.TestCase): +class CategoricalTest(test.TestCase, parameterized.TestCase): def testP(self): p = [0.2, 0.8] @@ -131,7 +132,7 @@ class CategoricalTest(test.TestCase): with self.test_session(): self.assertAllClose(dist.prob(0).eval(), 0.2) - def testCDFWithDynamicEventShape(self): + def 
testCDFWithDynamicEventShapeKnownNdims(self): """Test that dynamically-sized events with unknown shape work.""" batch_size = 2 histograms = array_ops.placeholder(dtype=dtypes.float32, @@ -167,6 +168,21 @@ class CategoricalTest(test.TestCase): self.assertAllClose(actual_cdf_one, expected_cdf_one) self.assertAllClose(actual_cdf_two, expected_cdf_two) + @parameterized.named_parameters( + ("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]), + ("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88])) + def testCDFWithDynamicEventShapeUnknownNdims( + self, events, histograms, expected_cdf): + """Test that dynamically-sized events with unknown rank (ndims) work.""" + event_ph = array_ops.placeholder_with_default(events, shape=None) + histograms_ph = array_ops.placeholder_with_default(histograms, shape=None) + dist = categorical.Categorical(probs=histograms_ph) + cdf_op = dist.cdf(event_ph) + + actual_cdf = self.evaluate(cdf_op) + self.assertAllClose(actual_cdf, expected_cdf) + def testCDFWithBatch(self): histograms = [[0.1, 0.2, 0.3, 0.25, 0.15], [0.0, 0.75, 0.2, 0.05, 0.0]] diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py index b88a0518b6d..dd25fce2ec8 100644 --- a/tensorflow/python/ops/distributions/categorical.py +++ b/tensorflow/python/ops/distributions/categorical.py @@ -32,12 +32,8 @@ from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export -def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32): +def _broadcast_cat_event_and_params(event, params, base_dtype): """Broadcasts the event or distribution parameters.""" - if event.shape.ndims is None: - raise NotImplementedError( - "Cannot broadcast with an event tensor of unknown rank.") - if event.dtype.is_integer: pass elif event.dtype.is_floating: @@ -47,15 +43,18 @@ def 
_broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32): else: raise TypeError("`value` should have integer `dtype` or " "`self.dtype` ({})".format(base_dtype)) - - if params.get_shape()[:-1] == event.get_shape(): - params = params - else: - params *= array_ops.ones_like( - array_ops.expand_dims(event, -1), dtype=params.dtype) + shape_known_statically = ( + params.shape.ndims is not None and + params.shape[:-1].is_fully_defined() and + event.shape.is_fully_defined()) + if not shape_known_statically or params.shape[:-1] != event.shape: + params *= array_ops.ones_like(event[..., array_ops.newaxis], + dtype=params.dtype) params_shape = array_ops.shape(params)[:-1] event *= array_ops.ones(params_shape, dtype=event.dtype) - event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1])) + if params.shape.ndims is not None: + event.set_shape(tensor_shape.TensorShape(params.shape[:-1])) + return event, params