diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index e82d604d58a..66b43662fd5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -325,7 +325,7 @@ class FillLowerTriangularTest(tf.test.TestCase): def testCorrectlyMakesNoBatchLowerTril(self): with self.test_session(): - x = np.arange(9) + x = tf.convert_to_tensor(np.arange(9, dtype=np.float32)) expected = np.array( [[0., 0., 0.], [1., 2., 0.], @@ -333,6 +333,10 @@ class FillLowerTriangularTest(tf.test.TestCase): actual = distribution_util.fill_lower_triangular(x) self.assertAllEqual(expected.shape, actual.get_shape()) self.assertAllEqual(expected, actual.eval()) + self.assertAllEqual( + np.concatenate([np.ones(6, dtype=np.float32), + np.zeros(3, dtype=np.float32)]), + tf.gradients(distribution_util.fill_lower_triangular(x), x)[0].eval()) def testCorrectlyMakesBatchLowerTril(self): with self.test_session(): diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index ad5fa5b5aee..e27dcfe9b3f 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -435,15 +435,14 @@ def fill_lower_triangular(x, name="fill_lower_triangular"): """ with ops.name_scope(name, values=(x,)): x = ops.convert_to_tensor(x, name="x") - ndims = x.get_shape().ndims - if ndims is not None and x.get_shape()[-1].value is not None: + if (x.get_shape().ndims is not None and + x.get_shape()[-1].value is not None): d = x.get_shape()[-1].value # d = n^2/2 + n/2 implies n is: n = int(0.5 * (math.sqrt(1. + 8. * d) - 1.)) final_shape = x.get_shape()[:-1].concatenate( tensor_shape.TensorShape([n, n])) else: - ndims = array_ops.rank(x) d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32) # d = n^2/2 + n/2 implies n is: n = math_ops.cast(0.5 * (dtypes.sqrt(1. + 8. * d) - 1.), @@ -494,7 +493,12 @@ def fill_lower_triangular(x, name="fill_lower_triangular"): array_ops.tile([tril_ids], [m, 1])]) idx = array_ops.transpose(idx, [1, 2, 0]) - y = array_ops.gather_nd(y, idx) + if x.get_shape().ndims == 1: + # Prefer using gather because it has a gradient. + # We wrap the result in a list so downstream logic "just works." + y = [array_ops.gather(y[0, :], tril_ids)] + else: + y = array_ops.gather_nd(y, idx) y = array_ops.reshape(y, array_ops.concat(0, [batch_shape, [n, n]])) y.set_shape(y.get_shape().merge_with(final_shape))