From a24cfc8d096dbee915190e10b16fe279fe90ee67 Mon Sep 17 00:00:00 2001 From: Brian Patton Date: Tue, 19 Mar 2019 13:20:21 -0700 Subject: [PATCH] RandomGamma should return an empty Tensor for empty alpha. Before this change, it returns an error "Input alpha should have non-zero element count" if empty `alpha` is presented with non-empty `shape`. PiperOrigin-RevId: 239255074 --- tensorflow/core/kernels/random_op.cc | 2 +- tensorflow/python/kernel_tests/random/random_gamma_test.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 996950b65f3..e39e5f2eb3b 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -173,7 +173,7 @@ class RandomGammaOp : public OpKernel { Tensor* samples_t = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t)); - if (num_samples == 0) return; + if (samples_shape.num_elements() == 0) return; using random::PhiloxRandom; diff --git a/tensorflow/python/kernel_tests/random/random_gamma_test.py b/tensorflow/python/kernel_tests/random/random_gamma_test.py index 5cc13f67777..2fbfdc0a963 100644 --- a/tensorflow/python/kernel_tests/random/random_gamma_test.py +++ b/tensorflow/python/kernel_tests/random/random_gamma_test.py @@ -53,6 +53,10 @@ class RandomGammaTest(test.TestCase): return func + def testEmptySamplingNoError(self): + self.evaluate(random_ops.random_gamma( + [5], alpha=np.ones([2, 0, 3]), beta=np.ones([3]), dtype=dtypes.float32)) + @test_util.run_deprecated_v1 def testMomentsFloat32(self): self._testMoments(dtypes.float32)